From c4ee6d9b73300d906b8df4602157d646df2415ef Mon Sep 17 00:00:00 2001
From: Robert Barron
Date: Sun, 30 Jul 2023 00:41:10 -0700
Subject: [PATCH 0001/1174] xyz_grid: allow varying the seed along an axis
 along with the axis's other changes

---
 scripts/xyz_grid.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 1010845e56b..4eb1b197d50 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -416,6 +416,10 @@ def ui(self, is_img2img):
             with gr.Column():
                 include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
                 include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
+            with gr.Column():
+                vary_seeds_x = gr.Checkbox(label='Vary seed on X axis', value=False, elem_id=self.elem_id("vary_seeds_x"))
+                vary_seeds_y = gr.Checkbox(label='Vary seed on Y axis', value=False, elem_id=self.elem_id("vary_seeds_y"))
+                vary_seeds_z = gr.Checkbox(label='Vary seed on Z axis', value=False, elem_id=self.elem_id("vary_seeds_z"))
             with gr.Column():
                 margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
 
@@ -475,9 +479,9 @@ def get_dropdown_update_from_params(axis,params):
             (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
         )
 
-        return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
+        return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size]
 
-    def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
+    def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size):
         if not no_fixed_seeds:
             modules.processing.fix_seed(p)
 
@@ -648,6 +652,16 @@ def cell(x, y, z, ix, iy, iz):
             y_opt.apply(pc, y, ys)
             z_opt.apply(pc, z, zs)
 
+            xdim = len(xs) if vary_seeds_x else 1
+            ydim = len(ys) if vary_seeds_y else 1
+
+            if vary_seeds_x:
+                pc.seed += ix
+            if vary_seeds_y:
+                pc.seed += iy * xdim
+            if vary_seeds_z:
+                pc.seed += iz * xdim * ydim
+
             res = process_images(pc)
 
             # Sets subgrid infotexts
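
A note on the seed arithmetic added to cell() above: each varied axis contributes its own index, scaled by the sizes of the axes before it, so every cell of the grid receives a distinct, reproducible seed offset. A self-contained sketch of the scheme (the helper name below is illustrative, not part of the patch):

# standalone sketch of the patch's seed-offset scheme (hypothetical helper)
def cell_seed(base_seed, ix, iy, iz, xs, ys, vary_x, vary_y, vary_z):
    xdim = len(xs) if vary_x else 1   # row stride counts only if X varies
    ydim = len(ys) if vary_y else 1   # plane stride counts only if Y varies
    offset = 0
    if vary_x:
        offset += ix
    if vary_y:
        offset += iy * xdim
    if vary_z:
        offset += iz * xdim * ydim
    return base_seed + offset

# a 3x2 grid varying the seed on X and Y walks seeds 100..105:
seeds = [cell_seed(100, ix, iy, 0, [1, 2, 3], [1, 2], True, True, False)
         for iy in range(2) for ix in range(3)]
assert seeds == [100, 101, 102, 103, 104, 105]
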
From 4130e5db3d5c2fa7cbfe9e09e5eadba1ed958ab0 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Fri, 25 Aug 2023 10:12:19 +0900
Subject: [PATCH 0002/1174] img2img batch PNG info model hash

---
 modules/img2img.py | 12 +++++++++++-
 modules/ui.py      |  2 +-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/modules/img2img.py b/modules/img2img.py
index 1519e132b2b..c81c7ab9e8c 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -10,6 +10,7 @@
 from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
+from modules.sd_models import get_closet_checkpoint_match
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html
@@ -41,7 +42,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
     cfg_scale = p.cfg_scale
     sampler_name = p.sampler_name
     steps = p.steps
-
+    override_settings = p.override_settings
+    sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
     for i, image in enumerate(images):
         state.job = f"{i+1} out of {len(images)}"
         if state.skipped:
@@ -104,6 +106,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
                 p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
                 p.steps = int(parsed_parameters.get("Steps", steps))
 
+            model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
+            if model_info is not None:
+                p.override_settings['sd_model_checkpoint'] = model_info.name
+            elif sd_model_checkpoint_override:
+                p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
+            else:
+                p.override_settings.pop("sd_model_checkpoint", None)
+
         proc = modules.scripts.scripts_img2img.run(p, *args)
         if proc is None:
             if output_dir:
diff --git a/modules/ui.py b/modules/ui.py
index 2b6a13cbb6c..9c5082c3190 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -614,7 +614,7 @@ def update_orig(image, state):
                         with gr.Accordion("PNG info", open=False):
                             img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
                             img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
-                            img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
+                            img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
 
 img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]

From 3369fb27df6c1badd39bcb59b3f71c61a47d3d91 Mon Sep 17 00:00:00 2001
From: SpenserCai
Date: Fri, 25 Aug 2023 22:15:35 +0800
Subject: [PATCH 0003/1174] support installed extensions list api

---
 modules/api/api.py    | 20 ++++++++++++++++++++
 modules/api/models.py |  9 +++++++++
 2 files changed, 29 insertions(+)

diff --git a/modules/api/api.py b/modules/api/api.py
index e6edffe7144..0bcf54977a0 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -243,6 +243,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
         self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=List[models.ExtensionItem])
 
         if shared.cmd_opts.api_server_stop:
             self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@@ -769,6 +770,25 @@ def get_memory(self):
         except Exception as err:
             cuda = {'error': f'{err}'}
         return models.MemoryResponse(ram=ram, cuda=cuda)
+
+    def 
get_extensions_list(self): + from modules import extensions + extensions.list_extensions() + ext_list = [] + for ext in extensions.extensions: + ext: extensions.Extension + ext.read_info_from_repo() + if ext.remote is not None: + ext_list.append({ + "name": ext.name, + "remote": ext.remote, + "branch": ext.branch, + "commit_hash":ext.commit_hash, + "commit_date":ext.commit_date, + "version":ext.version, + "enabled":ext.enabled + }) + return ext_list def launch(self, server_name, port, root_path): self.app.include_router(self.router) diff --git a/modules/api/models.py b/modules/api/models.py index 6a574771c33..731ab03da5f 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -311,3 +311,12 @@ class ScriptInfo(BaseModel): is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments") + +class ExtensionItem(BaseModel): + name: str = Field(title="Name", description="Extension name") + remote: str = Field(title="Remote", description="Extension Repository URL") + branch: str = Field(title="Branch", description="Extension Repository Branch") + commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash") + version: str = Field(title="Version", description="Extension Version") + commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date") + enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled") From dd07b5193efa547929629b310ef5c9ff0fc83a19 Mon Sep 17 00:00:00 2001 From: SpenserCai Date: Fri, 25 Aug 2023 22:23:17 +0800 Subject: [PATCH 0004/1174] fix format error --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 0bcf54977a0..785ee828784 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -770,7 +770,7 @@ def get_memory(self): except Exception as err: cuda = {'error': f'{err}'} return models.MemoryResponse(ram=ram, cuda=cuda) - + def get_extensions_list(self): from modules import extensions extensions.list_extensions() From db56bdce33264aab5e6b565a41df503d785bbbff Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 25 Aug 2023 16:04:06 -0400 Subject: [PATCH 0005/1174] Don't show hidden samplers in dropdown for XYZ script --- scripts/xyz_grid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index daaf761f165..517d6332ed2 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -238,9 +238,9 @@ def __init__(self, *args, **kwargs): AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), - AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), - AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), - AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), 
format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), + AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers if x.name not in opts.hide_samplers]), + AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]), + AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)), AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")), AxisOption("Sigma Churn", float, apply_field("s_churn")), From bb90b0ff42ea55cbc73df15ea1ef8fd79af2e026 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 26 Aug 2023 06:33:48 +0300 Subject: [PATCH 0006/1174] fix defaults settings page breaking when any of main UI tabs are hidden --- modules/ui.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 9c5082c3190..f4028475661 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1279,11 +1279,8 @@ def get_textual_inversion_template_names(): with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"): interface.render() - for interface, _label, ifid in interfaces: - if ifid in ["extensions", "settings"]: - continue - - loadsave.add_block(interface, ifid) + if ifid not in ["extensions", "settings"]: + loadsave.add_block(interface, ifid) loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs) From 72ee347eabf04d1a238a738a03e7973cc2a46ca3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 26 Aug 2023 06:52:18 +0300 Subject: [PATCH 0007/1174] update pnginfo checkpoint to return dict with parsed values --- modules/api/api.py | 10 ++++------ modules/api/models.py | 3 ++- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 785ee828784..844e31ee75a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -474,9 +474,6 @@ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest): return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) def pnginfoapi(self, req: models.PNGInfoRequest): - if(not req.image.strip()): - return models.PNGInfoResponse(info="") - image = decode_base64_to_image(req.image.strip()) if image is None: return models.PNGInfoResponse(info="") @@ -485,9 +482,10 @@ def pnginfoapi(self, req: models.PNGInfoRequest): if geninfo is None: 
geninfo = "" - items = {**{'parameters': geninfo}, **items} + params = generation_parameters_copypaste.parse_generation_parameters(geninfo) + script_callbacks.infotext_pasted_callback(geninfo, params) - return models.PNGInfoResponse(info=geninfo, items=items) + return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) def progressapi(self, req: models.ProgressRequest = Depends()): # copy from check_progress_call of ui.py diff --git a/modules/api/models.py b/modules/api/models.py index 731ab03da5f..94eca97dcb0 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -178,7 +178,8 @@ class PNGInfoRequest(BaseModel): class PNGInfoResponse(BaseModel): info: str = Field(title="Image info", description="A string with the parameters used to generate the image") - items: dict = Field(title="Items", description="An object containing all the info the image had") + items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had") + parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields") class ProgressRequest(BaseModel): skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") From ec54257cb21bacd6281a5f9c6f74c2529fe446c5 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 26 Aug 2023 07:00:09 -0400 Subject: [PATCH 0008/1174] Hide broken image crop tool for now --- style.css | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/style.css b/style.css index d67b63363b5..5090f289750 100644 --- a/style.css +++ b/style.css @@ -2,6 +2,14 @@ @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); +/* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */ + + +div.gradio-image button[aria-label="Edit"] { + display: none; +} + + /* general gradio fixes */ :root, .dark{ From 73f69a74534be17c020fd1a5e64dfce71981fc31 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 26 Aug 2023 07:04:11 -0400 Subject: [PATCH 0009/1174] Fix CSS whitespace --- style.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/style.css b/style.css index 5090f289750..e336e79df2e 100644 --- a/style.css +++ b/style.css @@ -2,8 +2,8 @@ @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); -/* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */ +/* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */ div.gradio-image button[aria-label="Edit"] { display: none; From 168eac319d0f45c778d5b9d35dd5ce280f8d5094 Mon Sep 17 00:00:00 2001 From: Daniel Dengler Date: Sat, 26 Aug 2023 23:22:57 +0200 Subject: [PATCH 0010/1174] is_automatic is missing () for call --- modules/sd_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 669097daade..31306d8ba4b 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -159,7 +159,7 @@ def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution: def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution: vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) - if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic): + if 
vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic()):
         return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
 
     return VaeResolution(resolved=False)

From 9d8d279d0d4d103e1b7d0bad21a3eb835dbab9aa Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Sat, 26 Aug 2023 15:56:17 -0400
Subject: [PATCH 0011/1174] Prevent duplicate resize handler

---
 javascript/resizeHandle.js | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/javascript/resizeHandle.js b/javascript/resizeHandle.js
index 2fd3c4d2982..8c5c5169210 100644
--- a/javascript/resizeHandle.js
+++ b/javascript/resizeHandle.js
@@ -134,6 +134,8 @@
 onUiLoaded(function() {
     for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) {
-        setupResizeHandle(elem);
+        if (!elem.querySelector('.resize-handle')) {
+            setupResizeHandle(elem);
+        }
     }
 });

From b7f0e815624dab182aff406c8f227b39ec17452f Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 27 Aug 2023 08:41:26 +0300
Subject: [PATCH 0012/1174] fix error that causes some extra networks to be
 disabled if both <lora:...> and <lyco:...> are present in the prompt

---
 modules/extra_networks.py | 58 ++++++++++++++++++++++++++++-----------
 1 file changed, 42 insertions(+), 16 deletions(-)

diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index fa28ac752ac..b9533677887 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -1,6 +1,7 @@
 import json
 import os
 import re
+import logging
 from collections import defaultdict
 
 from modules import errors
@@ -86,27 +87,55 @@ def deactivate(self, p):
         raise NotImplementedError
 
 
-def activate(p, extra_network_data):
-    """call activate for extra networks in extra_network_data in specified order, then call
-    activate for all remaining registered networks with an empty argument list"""
+def lookup_extra_networks(extra_network_data):
+    """returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.
 
-    activated = []
+    Example input:
+    {
+        'lora': [<ExtraNetworkParams object>],
+        'lyco': [<ExtraNetworkParams object>],
+        'hypernet': [<ExtraNetworkParams object>]
+    }
+
+    Example output:
+
+    {
+        <ExtraNetworkLora object>: [<ExtraNetworkParams object>, <ExtraNetworkParams object>],
+        <ExtraNetworkHypernet object>: [<ExtraNetworkParams object>]
+    }
+    """
 
-    for extra_network_name, extra_network_args in extra_network_data.items():
+    res = {}
+
+    for extra_network_name, extra_network_args in list(extra_network_data.items()):
         extra_network = extra_network_registry.get(extra_network_name, None)
+        alias = extra_network_aliases.get(extra_network_name, None)
 
-        if extra_network is None:
-            extra_network = extra_network_aliases.get(extra_network_name, None)
+        if alias is not None and extra_network is None:
+            extra_network = alias
 
         if extra_network is None:
-            print(f"Skipping unknown extra network: {extra_network_name}")
+            logging.info(f"Skipping unknown extra network: {extra_network_name}")
             continue
 
+        res.setdefault(extra_network, []).extend(extra_network_args)
+
+    return res
+
+
+def activate(p, extra_network_data):
+    """call activate for extra networks in extra_network_data in specified order, then call
+    activate for all remaining registered networks with an empty argument list"""
+
+    activated = []
+
+    for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():
+
         try:
             extra_network.activate(p, extra_network_args)
             activated.append(extra_network)
         except Exception as e:
-            errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
+            errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")
 
     for extra_network_name, extra_network in extra_network_registry.items():
         if extra_network in activated:
@@ -125,19 +154,16 @@ def deactivate(p, extra_network_data):
     """call deactivate for extra networks in extra_network_data in specified order, then call
     deactivate for all remaining registered networks"""
 
-    for extra_network_name in extra_network_data:
-        extra_network = extra_network_registry.get(extra_network_name, None)
-        if extra_network is None:
-            continue
+    data = lookup_extra_networks(extra_network_data)
 
+    for extra_network in data:
         try:
             extra_network.deactivate(p)
         except Exception as e:
-            errors.display(e, f"deactivating extra network {extra_network_name}")
+            errors.display(e, f"deactivating extra network {extra_network.name}")
 
     for extra_network_name, extra_network in extra_network_registry.items():
-        args = extra_network_data.get(extra_network_name, None)
-        if args is not None:
+        if extra_network in data:
             continue
 
         try:
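
The substance of the fix above: activation used to be keyed by tag name, so a prompt using both the 'lora' and 'lyco' spellings (the latter an alias for the former) resolved to the same handler twice, and one set of arguments could knock out the other. lookup_extra_networks keys the result by handler object instead, merging aliased tags. A minimal standalone sketch with hypothetical handler names:

# minimal sketch of the merge in lookup_extra_networks (names hypothetical)
registry = {"lora": "LORA_HANDLER", "hypernet": "HYPERNET_HANDLER"}
aliases = {"lyco": "LORA_HANDLER"}  # 'lyco' tags are served by the lora handler

def merge(extra_network_data):
    res = {}
    for name, args in extra_network_data.items():
        handler = registry.get(name) or aliases.get(name)
        if handler is None:
            continue  # unknown network type: skip rather than fail
        # 'lora' and 'lyco' now land under the same handler key, so neither
        # set of arguments shadows the other
        res.setdefault(handler, []).extend(args)
    return res

print(merge({"lora": [["a", "1"]], "lyco": [["b", "0.5"]]}))
# {'LORA_HANDLER': [['a', '1'], ['b', '0.5']]}
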
From cb5f0823c6f7fadb5eb81b93e5a587a11856b478 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 27 Aug 2023 08:45:16 +0300
Subject: [PATCH 0013/1174] update gradio to 3.41.2

---
 modules/errors.py         | 2 +-
 requirements.txt          | 2 +-
 requirements_versions.txt | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/modules/errors.py b/modules/errors.py
index a56fd30ca3a..8c339464d46 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -95,7 +95,7 @@ def check_versions():
     expected_torch_version = "2.0.0"
     expected_xformers_version = "0.0.20"
-    expected_gradio_version = "3.41.0"
+    expected_gradio_version = "3.41.2"
 
     if version.parse(torch.__version__) < version.parse(expected_torch_version):
         print_error_explanation(f"""
diff --git a/requirements.txt b/requirements.txt
index 960fa0bd7be..80b438455ce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,7 @@ clean-fid
 einops
 fastapi>=0.90.1
 gfpgan
-gradio==3.41.0
+gradio==3.41.2
 inflection
 jsonmerge
 kornia
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 6c679e242c5..f8ae1f385ae 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -7,7 +7,7 @@ clean-fid==0.1.35
 einops==0.4.1
 fastapi==0.94.0
 gfpgan==1.3.8
-gradio==3.41.0
+gradio==3.41.2
 httpcore==0.15
 inflection==0.5.1
 jsonmerge==1.8.0

From f2c55523c0cde3dca5cc154c45046316396b14a6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 27 Aug 2023 08:45:25 +0300
Subject: [PATCH 0014/1174] update changelog

---
 CHANGELOG.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0c5e0f11990..5e78b3d2d9e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -67,7 +67,7 @@
  * make it possible to localize tooltips and placeholders
 
 ### Extensions and API:
- * gradio 3.41.0
+ * gradio 3.41.2
  * also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd
  * support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world')
  * properly clear the total console progressbar when using txt2img and img2img from API
@@ -127,6 +127,9 @@
  * set devices.dtype_unet correctly
  * run RealESRGAN on GPU for non-CUDA devices ([#12737](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
  * prevent extra network buttons being obscured by description for very small card sizes ([#12745](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12745))
+ * fix error that causes some extra networks to be disabled if both <lora:...> and <lyco:...> are present in the prompt
+ * fix defaults settings page breaking when any of main UI tabs are hidden
+ * fix incorrect save/display of new values in Defaults page in settings
 
 ## 1.5.2

From bd5c16e8da5837b2b08fe6e329694553dd688a5f Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 27 Aug 2023 09:19:02 +0300
Subject: [PATCH 0015/1174] fix for Reload UI function: if you reload UI on
 one tab, other opened tabs will no longer stop working

---
 CHANGELOG.md                               | 1 +
 modules/generation_parameters_copypaste.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e78b3d2d9e..1bbde23482f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -130,6 +130,7 @@
  * fix error that causes some extra networks to be disabled if both <lora:...> and <lyco:...> are present in the prompt
  * fix defaults settings page breaking when any of main UI tabs are hidden
  * fix incorrect save/display of new values in Defaults page in settings
+ * fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working
 
 ## 1.5.2
 
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 386517acaef..2ca1605547c 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -32,6 +32,7 @@ def __init__(self, paste_button, tabname, source_text_component=None, source_ima
 
 def reset():
     paste_fields.clear()
+    registered_param_bindings.clear()
 
 
 def quote(text):

From 23c6b5f1242f89c37a094d7a9237491c1ca1da34 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 27 Aug 2023 09:39:37 +0300
Subject: [PATCH 0016/1174] fix style editing dialog breaking if it's opened
 in both img2img and txt2img tabs

---
 javascript/extraNetworks.js | 9 +++++++++
 modules/ui_common.py        | 2 +-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 3bc723d3718..ad1a4e0008e 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -249,6 +249,15 @@ function 
popup(contents) { globalPopup.style.display = "flex"; } +var storedPopupIds = {}; +function popupId(id) { + if(! storedPopupIds[id]){ + storedPopupIds[id] = gradioApp().getElementById(id); + } + + popup(storedPopupIds[id]); +} + function extraNetworksShowMetadata(text) { var elem = document.createElement('pre'); elem.classList.add('popup-metadata'); diff --git a/modules/ui_common.py b/modules/ui_common.py index eddc4bc882c..84a7d7f2756 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -261,7 +261,7 @@ def setup_dialog(button_show, dialog, *, button_close=None): fn=lambda: gr.update(visible=True), inputs=[], outputs=[dialog], - ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }") + ).then(fn=None, _js="function(){ popupId('" + dialog.elem_id + "'); }") if button_close: button_close.click(fn=None, _js="closePopup") From 897312de46352c39d03b6811844c128426fbae80 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 27 Aug 2023 09:44:13 +0300 Subject: [PATCH 0017/1174] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1bbde23482f..07798b5a3ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -131,6 +131,10 @@ * fix defaults settings page breaking when any of main UI tabs are hidden * fix incorrect save/display of new values in Defaults page in settings * fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working + * fix an error that prevents VAE being reloaded after an option change if a VAE near the checkpoint exists ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * fix style editing dialog breaking if it's opened in both img2img and txt2img tabs ## 1.5.2 From 63d3150dc4f5c4452a4a385329eb8954f53d6451 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 27 Aug 2023 10:11:14 +0300 Subject: [PATCH 0018/1174] lint --- javascript/extraNetworks.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index ad1a4e0008e..493f31af28a 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -251,7 +251,7 @@ function popup(contents) { var storedPopupIds = {}; function popupId(id) { - if(! 
storedPopupIds[id]){
+    if (!storedPopupIds[id]) {
         storedPopupIds[id] = gradioApp().getElementById(id);
     }

From ec54257cb21bacd6281a5f9c6f74c2529fe446c5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 27 Aug 2023 20:16:50 +0300
Subject: [PATCH 0019/1174] hide --gradio-auth and --api-auth values from
 /internal/sysinfo report

---
 modules/sysinfo.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/modules/sysinfo.py b/modules/sysinfo.py
index 058e66ce4e9..2db7551dcf0 100644
--- a/modules/sysinfo.py
+++ b/modules/sysinfo.py
@@ -82,7 +82,7 @@ def get_dict():
         "Data path": paths_internal.data_path,
         "Extensions dir": paths_internal.extensions_dir,
         "Checksum": checksum_token,
-        "Commandline": sys.argv,
+        "Commandline": get_argv(),
         "Torch env info": get_torch_sysinfo(),
         "Exceptions": get_exceptions(),
         "CPU": {
@@ -123,6 +123,22 @@ def get_environment():
     return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
 
 
+def get_argv():
+    res = []
+
+    for v in sys.argv:
+        if shared.cmd_opts.gradio_auth and shared.cmd_opts.gradio_auth == v:
+            res.append("<hidden>")
+            continue
+
+        if shared.cmd_opts.api_auth and shared.cmd_opts.api_auth == v:
+            res.append("<hidden>")
+            continue
+
+        res.append(v)
+
+    return res
+
+
 re_newline = re.compile(r"\r*\n")

From e422f19ee9f78066ec023d6b3b0a307a677c6ad9 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Mon, 28 Aug 2023 03:27:07 +0900
Subject: [PATCH 0020/1174] non-local condition

---
 modules/shared_cmd_options.py | 2 +-
 webui.py                      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/shared_cmd_options.py b/modules/shared_cmd_options.py
index af24938b05f..dd93f5206ce 100644
--- a/modules/shared_cmd_options.py
+++ b/modules/shared_cmd_options.py
@@ -15,4 +15,4 @@
 else:
     cmd_opts, _ = parser.parse_known_args()
 
-cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
+cmd_opts.disable_extension_access = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) and not cmd_opts.enable_insecure_extension_access
diff --git a/webui.py b/webui.py
index 5c827dae87d..12328423d0d 100644
--- a/webui.py
+++ b/webui.py
@@ -74,7 +74,7 @@ def webui():
         if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
             auto_launch_browser = True
         elif shared.opts.auto_launch_browser == "Local":
-            auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok])
+            auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok, cmd_opts.server_name])
 
     app, local_url, share_url = shared.demo.launch(
         share=cmd_opts.share,

From 18e3e6d6abfc084324cc8ae13f70ba4af5ddc35f Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Mon, 28 Aug 2023 03:42:02 +0900
Subject: [PATCH 0021/1174] consolidate local check

---
 modules/shared_cmd_options.py | 4 ++--
 webui.py                      | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/modules/shared_cmd_options.py b/modules/shared_cmd_options.py
index dd93f5206ce..c9626667fb8 100644
--- a/modules/shared_cmd_options.py
+++ b/modules/shared_cmd_options.py
@@ -14,5 +14,5 @@
 else:
     cmd_opts, _ = parser.parse_known_args()
 
-
-cmd_opts.disable_extension_access = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) and not cmd_opts.enable_insecure_extension_access
+cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, 
cmd_opts.server_name]) +cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access diff --git a/webui.py b/webui.py index 12328423d0d..9ed20b30672 100644 --- a/webui.py +++ b/webui.py @@ -74,7 +74,7 @@ def webui(): if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch: auto_launch_browser = True elif shared.opts.auto_launch_browser == "Local": - auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok, cmd_opts.server_name]) + auto_launch_browser = not cmd_opts.webui_is_non_local app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, From 2b8484a29d7d1dbfd69d97616e4617cb02006192 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 27 Aug 2023 16:25:26 -0400 Subject: [PATCH 0022/1174] Add missing infotext for RNG --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 83f56314900..0f054f472c8 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -144,7 +144,7 @@ "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), - "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), + "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}, infotext="RNG").info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"), })) From f3d1631aab82d559294126a9230c979ef4c4e1d6 Mon Sep 17 00:00:00 2001 From: JaredTherriault Date: Sun, 27 Aug 2023 21:54:05 -0700 Subject: [PATCH 0023/1174] Offloading custom work -custom_statics works to do mass replace strings, intended for copy-pasting gen info from internet generations and replacing unsavory prompts with safer prompts for my own sanity -tried to implement this into generation_parameters_copypaste but it didn't work out this iteration, presumably because we return a string and the calling method is looking for an object type -updated webui-user.bat to set a custom temp directory (for disk space concerns) and to apply xformers (for generation speed) I probably won't be merging any of this work into the main repo since I don't want to mess with anyone else's prompts, this is just intended to keep my workspace safe from anything I don't want to see. Eventually this should be done in an extension which I could then publish, but I need to learn a lot more about the extension and callback systems in the main repo first. 
just uploading this to my fork for now so I don't lose the current progress.
---
 modules/custom_statics.py                  | 29 ++++++++++++++++++++++
 modules/generation_parameters_copypaste.py |  8 ++++++
 webui-user.bat                             |  4 +--
 3 files changed, 39 insertions(+), 2 deletions(-)
 create mode 100644 modules/custom_statics.py

diff --git a/modules/custom_statics.py b/modules/custom_statics.py
new file mode 100644
index 00000000000..207bd5fbb90
--- /dev/null
+++ b/modules/custom_statics.py
@@ -0,0 +1,29 @@
+import os
+import gc
+import re
+
+import modules.paths as paths
+
+class CustomStatics:
+
+    @staticmethod
+    # loads a file with strings structured as below, on each line with a : between the search and replace strings, into a list
+    # search0:replace0
+    # search string:replace string
+    #
+    # Then replaces all occurrences of the list's search strings with the list's replace strings in one go
+    def mass_replace_strings(input_string):
+        with open(os.path.join(paths.data_path, "custom_statics/Replacements.txt"), "r", encoding="utf8") as file:
+            replacements = file.readlines()
+
+        replacement_dict = {}
+        for line in replacements:
+            search, replace = line.strip().split(":")
+            replacement_dict[search] = replace
+
+        def replace(match_text):
+            return replacement_dict[match_text.group(0)]
+
+        return re.sub('|'.join(r'\b%s\b' % re.escape(s) for s in replacement_dict.keys()), replace, str(input_string))
+
+        return str(geninfo)
\ No newline at end of file
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index a3448be9db8..bd7b0018350 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -370,6 +370,14 @@ def paste_func(prompt):
                 prompt = file.read()
 
         params = parse_generation_parameters(prompt)
+
+        # This sanitizes unsavory prompt words when copying from another image
+        # for my own sanity. This is not intended to be contributed to the main repo,
+        # it's just so I don't have to see anything I'm not interested in when batch
+        # reproducing images from civit.ai or elsewhere when working on loras
+        # todo: make this work with the callback instead of forcing it here, this can be an extension when I feel like putting it together :D
+        from modules import custom_statics
+        params = custom_statics.CustomStatics.mass_replace_strings(params)
         script_callbacks.infotext_pasted_callback(prompt, params)
         res = []
 
diff --git a/webui-user.bat b/webui-user.bat
index e5a257bef06..1ba2116d01c 100644
--- a/webui-user.bat
+++ b/webui-user.bat
@@ -1,8 +1,8 @@
 @echo off
-
+set TEMP=G:\SD-temp
 set PYTHON=
 set GIT=
 set VENV_DIR=
-set COMMANDLINE_ARGS=
+set COMMANDLINE_ARGS= --xformers
 
 call webui.bat
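
The replacement logic in mass_replace_strings boils down to one compiled alternation of word-bounded, escaped search strings; note that the trailing return str(geninfo) in the new file is unreachable and refers to a name that doesn't exist. A self-contained sketch of the same technique, with the search/replace pairs inlined instead of read from Replacements.txt:

import re

# hypothetical pairs; the patch reads these from a "search:replace" file
replacement_dict = {"bad word": "flower", "worse word": "kitten"}

# one alternation of escaped search strings, each anchored at word boundaries
pattern = re.compile('|'.join(r'\b%s\b' % re.escape(s) for s in replacement_dict))

def mass_replace(text):
    # every occurrence is swapped in a single pass over the input
    return pattern.sub(lambda m: replacement_dict[m.group(0)], text)

print(mass_replace("a bad word and a worse word"))  # a flower and a kitten
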
From f898833ea38718e87b39ab090b2a2325638559cb Mon Sep 17 00:00:00 2001
From: omahs <73983677+omahs@users.noreply.github.com>
Date: Mon, 28 Aug 2023 10:43:13 +0200
Subject: [PATCH 0024/1174] fix typos

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index b796d150041..007d6fde2b9 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@ A browser interface based on Gradio library for Stable Diffusion.
 - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
 - Now without any bad letters!
 - Load checkpoints in safetensors format
-- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
+- Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
 - Now with a license!
 - Reorder elements in the UI from settings screen
 
@@ -100,7 +100,7 @@ Alternatively, use online services (like Google Colab):
 - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
 
 ### Installation on Windows 10/11 with NVidia-GPUs using release package
-1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
+1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract its contents.
 2. Run `update.bat`.
 3. Run `run.bat`.
 > For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)

From 99acbd5ebe3e43e7d07905c7fc274b321cc905be Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 28 Aug 2023 11:17:47 -0400
Subject: [PATCH 0025/1174] Don't print blank stdout in extension installers

---
 modules/launch_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 7e4d5a61392..43b986c6a67 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -228,7 +228,9 @@ def run_extension_installer(extension_dir):
         env = os.environ.copy()
         env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"
 
-        print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
+        stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)
+        if stdout:
+            print(stdout)
     except Exception as e:
         errors.report(str(e))

From 20df81b0cc146c117c8a8c002997941063b32f13 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 28 Aug 2023 11:26:50 -0400
Subject: [PATCH 0026/1174] Honor `--skip-install` for extension installers

---
 modules/launch_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 7e4d5a61392..27d7d617ac7 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -404,7 +404,8 @@ def prepare_environment():
         run_pip(f"install -r \"{requirements_file}\"", "requirements")
         startup_timer.record("install requirements")
 
-    run_extensions_installers(settings_file=args.ui_settings_file)
+    if not args.skip_install:
+        run_extensions_installers(settings_file=args.ui_settings_file)
 
     if args.update_check:
         version_check(commit)

From 592b0dcfa705c42654eff48a40a51d8d6924f987 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 28 Aug 2023 12:09:37 -0400
Subject: [PATCH 0027/1174] Fix notification not playing when built-in webui
 tab is inactive

---
 javascript/notification.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/javascript/notification.js b/javascript/notification.js
index 76c5715dab4..6d79956125c 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -15,7 +15,7 @@ onAfterUiUpdate(function() {
         }
     }
 
-    const galleryPreviews = 
gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img');
+    const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"] div[id$="_results"] .thumbnail-item > img');
 
     if (galleryPreviews == null) return;
 

From cd48308a2a37b1e838b1b0cc5e8e507a174b14fb Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 28 Aug 2023 22:22:35 +0300
Subject: [PATCH 0028/1174] always show NV as RNG source in infotext

---
 modules/processing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/processing.py b/modules/processing.py
index 7dc931ba53e..0138e5ac9be 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -689,7 +689,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
         "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
         "Init image hash": getattr(p, 'init_img_hash', None),
-        "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
+        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         "Tiling": "True" if p.tiling else None,
         **p.extra_generation_params,

From 739686b1c521a77b0d3ce43c1f3b123df0317504 Mon Sep 17 00:00:00 2001
From: bluelovers
Date: Tue, 29 Aug 2023 06:19:22 +0800
Subject: [PATCH 0029/1174] style: file-metadata word-break

---
 style.css | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/style.css b/style.css
index e336e79df2e..bbfb7d39557 100644
--- a/style.css
+++ b/style.css
@@ -1009,6 +1009,8 @@ div.block.gradio-box.edit-user-metadata {
 .edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{
     padding: 0.3em 1em;
+    overflow-wrap: anywhere;
+    word-break: break-word;
 }
 
 .edit-user-metadata .wrap.translucent{

From 1bb21f35102326da28e1360750a879386c015716 Mon Sep 17 00:00:00 2001
From: bluelovers
Date: Tue, 29 Aug 2023 06:25:16 +0800
Subject: [PATCH 0030/1174] feat: display file metadata path

https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/12289

---
 modules/ui_extra_networks_user_metadata.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index b11622a1a19..877d0285945 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -2,10 +2,29 @@
 import html
 import json
 import os.path
+from pathlib import Path
 
 import gradio as gr
 
 from modules import generation_parameters_copypaste, images, sysinfo, errors
+from modules.paths_internal import models_path
+
+
+def windows_to_unix_style(path):
+    return Path(path).as_posix()
+
+
+def exclude_root_path(root_path, path_to_exclude):
+    try:
+        relative_path = os.path.relpath(path_to_exclude, root_path)
+        # if the path is already outside root_path, relpath returns a path leading out of it,
+        # so we need to check whether the path is inside root_path
+        if not relative_path.startswith('..'):
+            return windows_to_unix_style(relative_path)
+    except ValueError:
+        pass
+    # if the path cannot be made relative, or it lies outside root_path, return the original path
+    return windows_to_unix_style(path_to_exclude)
 
 
 class UserMetadataEditor:
@@ -98,6 +117,7 @@ def get_metadata_table(self, name):
             stats = os.stat(filename)
             params = [
                 ('Filename: ', os.path.basename(filename)),
+                ('Path: ', exclude_root_path(models_path, filename)),
                 ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
                 ('Hash: ', shorthash),
                 ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
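
The exclude_root_path helper above leans on a property of os.path.relpath: a path outside the root comes back prefixed with '..', which is exactly what the startswith('..') check catches. A quick illustration (paths hypothetical; POSIX-style output shown):

import os

# inside the root: a plain relative path comes back
print(os.path.relpath("/data/models/Lora/a.safetensors", "/data/models"))
# -> Lora/a.safetensors

# outside the root: relpath climbs out with "..", so the helper rejects it
# and falls back to returning the original path
print(os.path.relpath("/data/outputs/b.png", "/data/models"))
# -> ../outputs/b.png
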
From d83a1ba65be1b0fbdba8f10212193c52dc8f5e90 Mon Sep 17 00:00:00 2001
From: bluelovers
Date: Tue, 29 Aug 2023 06:33:00 +0800
Subject: [PATCH 0031/1174] feat: display file metadata ss_output_name

https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/12289

---
 extensions-builtin/Lora/ui_edit_user_metadata.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py
index 390d9dde3fb..c7011909055 100644
--- a/extensions-builtin/Lora/ui_edit_user_metadata.py
+++ b/extensions-builtin/Lora/ui_edit_user_metadata.py
@@ -70,6 +70,7 @@ def get_metadata_table(self, name):
         metadata = item.get("metadata") or {}
 
         keys = {
+            'ss_output_name': "Output name:",
            'ss_sd_model_name': "Model:",
            'ss_clip_skip': "Clip skip:",
            'ss_network_module': "Kohya module:",

From 02e7824e6a9c3d74af7b383dd66a3a1c231ef082 Mon Sep 17 00:00:00 2001
From: ibrainventures
Date: Tue, 29 Aug 2023 02:04:07 +0200
Subject: [PATCH 0032/1174] [RC 1.6.1 - zoom is partly hidden] Update
 style.css

If an image / batch result image is higher or wider than the current
viewport, and is zoomed (left corner zoom icon), it is cut off at the
top and also to the left.

This new rule seems to be the culprit.

---
 style.css | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/style.css b/style.css
index e336e79df2e..56e2cb4c350 100644
--- a/style.css
+++ b/style.css
@@ -660,13 +660,6 @@ table.popup-table .link{
     min-height: 0;
 }
 
-#modalImage{
-    position: absolute;
-    top: 50%;
-    left: 50%;
-    transform: translateX(-50%) translateY(-50%);
-}
-
 .modalPrev,
 .modalNext {
     cursor: pointer;

From 5070ab80042e0cafb1eb97f5da9d1d871cb85de6 Mon Sep 17 00:00:00 2001
From: dhwz
Date: Tue, 29 Aug 2023 07:16:32 +0200
Subject: [PATCH 0033/1174] remove xformers Python version check

---
 modules/launch_utils.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 7e4d5a61392..14252c3a6de 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -366,17 +366,7 @@ def prepare_environment():
         startup_timer.record("install open_clip")
 
     if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
-        if platform.system() == "Windows":
-            if platform.python_version().startswith("3.10"):
-                run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
-            else:
-                print("Installation of xformers is not supported in this version of Python.")
-                print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
-                if not is_installed("xformers"):
-                    exit(0)
-        elif platform.system() == "Linux":
-            run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
-
+        run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
         startup_timer.record("install xformers")
 
     if not is_installed("ngrok") and args.ngrok:

From 7ab16e99eedf3b5da7e596218a585f6966aee4d8 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Tue, 29 Aug 2023 01:51:13 -0400
Subject: [PATCH 0034/1174] Add option to align with sgm repo sampling
 implementation

---
 modules/sd_samplers_kdiffusion.py | 14 ++++++++++++--
 modules/shared_options.py         |  1 +
 scripts/xyz_grid.py               |  1 +
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/modules/sd_samplers_kdiffusion.py 
b/modules/sd_samplers_kdiffusion.py
index b9e0d577656..a8a2735f4e5 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -144,7 +144,13 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
         sigmas = self.get_sigmas(p, steps)
         sigma_sched = sigmas[steps - t_enc - 1:]
 
-        xi = x + noise * sigma_sched[0]
+        if opts.sgm_noise_multiplier:
+            p.extra_generation_params["SGM noise multiplier"] = True
+            noise_multiplier = torch.sqrt(1.0 + sigma_sched[0] ** 2.0)
+        else:
+            noise_multiplier = sigma_sched[0]
+
+        xi = x + noise * noise_multiplier
 
         if opts.img2img_extra_noise > 0:
             p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
@@ -197,7 +203,11 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
 
         sigmas = self.get_sigmas(p, steps)
 
-        x = x * sigmas[0]
+        if opts.sgm_noise_multiplier:
+            p.extra_generation_params["SGM noise multiplier"] = True
+            x = x * torch.sqrt(1.0 + sigmas[0] ** 2.0)
+        else:
+            x = x * sigmas[0]
 
         extra_params_kwargs = self.initialize(p)
         parameters = inspect.signature(self.func).parameters
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 83f56314900..67f7a8df65a 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -309,6 +309,7 @@
     'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
     'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
     'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
+    'sgm_noise_multiplier': OptionInfo(False, "SGM noise multiplier", infotext='SGM noise multiplier').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818").info("Match initial noise to official SDXL implementation - only useful for reproducing images"),
     'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}, infotext='UniPC variant'),
     'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
     'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 517d6332ed2..939d86053bd 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -265,6 +265,7 @@ def __init__(self, *args, **kwargs):
     AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
     AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
     AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
+    AxisOption("SGM noise multiplier", str, apply_override('sgm_noise_multiplier', boolean=True), choices=boolean_choice(reverse=True)),
     AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
     AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),

From b6c1a1bbbf29a3041025aa336f6f843ffd7c7d46 Mon Sep 17 00:00:00 2001
From: a666 <19142162+a666@users.noreply.github.com>
Date: Fri, 25 Aug 2023 01:58:19 -0600
Subject: [PATCH 0035/1174] Fix some deprecated types

---
 modules/api/api.py                 | 26 +++++++++++++-------------
 modules/api/models.py              | 24 ++++++++++++------------
 modules/gitpython_hack.py          |  2 +-
 modules/prompt_parser.py           |  7 +++----
 modules/script_callbacks.py        |  6 +++---
 modules/sub_quadratic_attention.py |  4 ++--
 modules/ui.py                      |  3 +--
 7 files changed, 34 insertions(+), 38 deletions(-)

diff --git a/modules/api/api.py b/modules/api/api.py
index 844e31ee75a..905ef9c9536 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -29,7 +29,7 @@
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
-from typing import Dict, List, Any
+from typing import Any
 import piexif
 import piexif.helper
 from contextlib import closing
@@ -221,15 +221,15 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
         self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
         self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
-        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
-        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
-        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
-        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
-        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
-        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
-        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
-        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
-        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
         self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
@@ -242,8 +242,8 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
         self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
-        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
-        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=List[models.ExtensionItem])
+        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
+        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
 
         if shared.cmd_opts.api_server_stop:
             self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@@ -563,7 +563,7 @@ def get_config(self):
 
         return options
 
-    def set_config(self, req: Dict[str, Any]):
+    def set_config(self, req: dict[str, Any]):
         checkpoint_name = req.get("sd_model_checkpoint", None)
         if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
             raise RuntimeError(f"model {checkpoint_name!r} not found")
diff --git a/modules/api/models.py b/modules/api/models.py
index 94eca97dcb0..a0d80af8c0d 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,12 +1,10 @@
 import inspect
 
 from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional
-from typing_extensions import Literal
+from typing import Any, Optional, Literal
 from inflection import underscore
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
 from modules.shared import sd_upscalers, opts, parser
-from typing import Dict, List
 
 API_NOT_ALLOWED = [
     "self",
@@ -130,12 +128,12 @@ def generate_model(self):
         ).generate_model()
 
 class TextToImageResponse(BaseModel):
-    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
     parameters: dict
     info: str
 
 class ImageToImageResponse(BaseModel):
-    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
     parameters: dict
     info: str
 
@@ -168,10 +166,10 @@ class FileData(BaseModel):
     name: str = Field(title="File name")
 
 class ExtrasBatchImagesRequest(ExtrasBaseRequest):
-    imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+    imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
 
 class ExtrasBatchImagesResponse(ExtraBaseResponse):
-    images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+    images: list[str] = Field(title="Images", description="The generated images in base64 format.")
 
 class PNGInfoRequest(BaseModel):
     image: str = Field(title="Image", description="The base64 encoded PNG image")
@@ -233,8 +231,8 @@ class PreprocessResponse(BaseModel):
 
 class SamplerItem(BaseModel):
     name: str = Field(title="Name")
-    aliases: List[str] = Field(title="Aliases")
-    options: Dict[str, str] = Field(title="Options")
+    aliases: list[str] = Field(title="Aliases")
+    options: dict[str, str] = Field(title="Options")
 
 class UpscalerItem(BaseModel):
     name: str = Field(title="Name")
@@ -285,8 +283,8 @@ class EmbeddingItem(BaseModel):
     vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
 
 class EmbeddingsResponse(BaseModel):
-    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
-    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+    loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+    skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
 
 class MemoryResponse(BaseModel):
     ram: dict = Field(title="RAM", description="System memory stats")
@@ -304,14 +302,14 @@ class ScriptArg(BaseModel):
     minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argument in UI")
     maximum: Optional[Any] = Field(default=None, title="Maximum", description="Maximum allowed value for the argument in UI")
     step: Optional[Any] = Field(default=None, title="Step", description="Step for changing value of the argument in UI")
-    choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+    choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
 
 class ScriptInfo(BaseModel):
     name: str = Field(default=None, title="Name", description="Script name")
     is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
     is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
-    args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+    args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
 
 class ExtensionItem(BaseModel):
     name: str = Field(title="Name", description="Extension name")
diff --git a/modules/gitpython_hack.py b/modules/gitpython_hack.py
index e537c1df93e..b55f0640e5e 100644
--- a/modules/gitpython_hack.py
+++ b/modules/gitpython_hack.py
@@ -23,7 +23,7 @@ def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
         )
         return self._parse_object_header(ret)
 
-    def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
+    def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
         # Not really streaming, per se; this buffers the entire object in memory.
         # Shouldn't be a problem for our use case, since we're only using this for
         # object headers (commit objects).
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 334efeef317..ddf4d2dd4f3 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -2,7 +2,6 @@
 import re
 from collections import namedtuple
-from typing import List
 import lark
 
 # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
@@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
 
 class ComposableScheduledPromptConditioning:
     def __init__(self, schedules, weight=1.0):
-        self.schedules: List[ScheduledPromptConditioning] = schedules
+        self.schedules: list[ScheduledPromptConditioning] = schedules
         self.weight: float = weight
 
 
 class MulticondLearnedConditioning:
     def __init__(self, shape, batch):
         self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
-        self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
 
 
 def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
@@ -278,7 +277,7 @@ def shape(self):
         return self["crossattn"].shape
 
 
-def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
+def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
     is_dict = isinstance(param, dict)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index fab23551a68..9c2a6541bbc 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,7 +1,7 @@
 import inspect
 import os
 from collections import namedtuple
-from typing import Optional, Dict, Any
+from typing import Optional, Any
 
 from fastapi import FastAPI
 from gradio import Blocks
@@ -255,7 +255,7 @@ def image_grid_callback(params: ImageGridLoopParams):
             report_exception(c, 'image_grid')
 
 
-def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
+def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
     for c in callback_map['callbacks_infotext_pasted']:
         try:
             c.callback(infotext, params)
@@ -446,7 +446,7 @@ def on_infotext_pasted(callback):
     """register a function to be called before applying an infotext.
     The callback is called with two arguments:
     - infotext: str - raw infotext.
-    - result: Dict[str, any] - parsed infotext parameters.
+    - result: dict[str, any] - parsed infotext parameters.
     """
     add_callback(callback_map['callbacks_infotext_pasted'], callback)
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index ae4ee4bbec0..4cb561ef207 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -15,7 +15,7 @@
 from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 import math
-from typing import Optional, NamedTuple, List
+from typing import Optional, NamedTuple
 
 
 def narrow_trunc(
@@ -97,7 +97,7 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk:
         )
         return summarize_chunk(query, key_chunk, value_chunk)
 
-    chunks: List[AttnChunk] = [
+    chunks: list[AttnChunk] = [
         chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
     ]
     acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
diff --git a/modules/ui.py b/modules/ui.py
index f4028475661..9a569182364 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1338,7 +1338,6 @@ def versions_html():
 
 def setup_ui_api(app):
     from pydantic import BaseModel, Field
-    from typing import List
 
     class QuicksettingsHint(BaseModel):
         name: str = Field(title="Name of the quicksettings field")
@@ -1347,7 +1346,7 @@ class QuicksettingsHint(BaseModel):
     def quicksettings_hint():
         return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
 
-    app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
+    app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])
 
     app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])

From 04b90328c0aa86670cfe5d31612d341e29b5a258 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Tue, 29 Aug 2023 15:38:05 +0300
Subject: [PATCH 0036/1174] revert SGM noise multiplier change for img2img because it breaks hires fix

---
 modules/sd_samplers_kdiffusion.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index a8a2735f4e5..72c352a6e6e 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -144,13 +144,7 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
         sigmas = self.get_sigmas(p, steps)
         sigma_sched = sigmas[steps - t_enc - 1:]
 
-        if opts.sgm_noise_multiplier:
-            p.extra_generation_params["SGM noise multiplier"] = True
-            noise_multiplier = torch.sqrt(1.0 + sigma_sched[0] ** 2.0)
-        else:
-            noise_multiplier = sigma_sched[0]
-
-        xi = x + noise * noise_multiplier
+        xi = x + noise * sigma_sched[0]
 
         if opts.img2img_extra_noise > 0:
             p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
From ba7d0d225a97a7d79a150d1f649011f54552b3bf Mon Sep 17 00:00:00 2001
From: ibrainventures
Date: Tue, 29 Aug 2023 15:31:01 +0200
Subject: [PATCH 0037/1174] Update style.css

---
 style.css | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/style.css b/style.css
index 56e2cb4c350..34869b8955b 100644
--- a/style.css
+++ b/style.css
@@ -621,6 +621,9 @@ table.popup-table .link{
 .modalControls {
     display: flex;
+    position: absolute;
+    right: 0px;
+    left: 0px;
     gap: 1em;
     padding: 1em;
     background-color:rgba(0,0,0,0);

From f564d8ed2c5c5644101c5670d2cec15b03ccb51b Mon Sep 17 00:00:00 2001
From: bluelovers
Date: Tue, 29 Aug 2023 22:11:18 +0800
Subject: [PATCH 0038/1174] refactor: refactor function

---
 modules/ui_extra_networks_user_metadata.py | 21 +++++---------------
 1 file changed, 6 insertions(+), 15 deletions(-)

diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index 877d0285945..588f84c75e2 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -10,21 +10,12 @@
 from modules.paths_internal import models_path
 
 
-def windows_to_unix_style(path):
-    return Path(path).as_posix()
-
-
-def exclude_root_path(root_path, path_to_exclude):
-    try:
-        relative_path = os.path.relpath(path_to_exclude, root_path)
-        # if the path is already outside root_path, relpath returns an absolute path,
-        # so we need to check whether the path is still inside root_path
-        if not relative_path.startswith('..'):
-            return windows_to_unix_style(relative_path)
-    except ValueError:
-        pass
-    # if the path cannot be made relative, or lies outside root_path, return the original path
-    return windows_to_unix_style(path_to_exclude)
+def exclude_root_path(root_path, path):
+    path_object = Path(path)
+    if path_object.is_relative_to(root_path):
+        path_object = path_object.relative_to(root_path)
+
+    return path_object.as_posix()
 
 
 class UserMetadataEditor:

From cb2a4f24247c6159740813935a973a6fe1ccc30e Mon Sep 17 00:00:00 2001
From: bluelovers
Date: Tue, 29 Aug 2023 22:47:10 +0800
Subject: [PATCH 0039/1174] chore: change extension time format

---
 modules/ui_extensions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index e01382676a2..83dcd303945 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -177,7 +177,7 @@ def extension_table():
                 {remote}
                 {ext.branch}
                 {version_link}
-                {time.asctime(time.gmtime(ext.commit_date))}
+                {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(ext.commit_date))}
                 {ext_status}
 """

From e3939f33394de31594f7c459a7bd352d206f7669 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Tue, 29 Aug 2023 12:19:10 -0400
Subject: [PATCH 0040/1174] Do not change quicksettings value when value returned is `None`

---
 modules/ui_settings.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 6dde4b6aa04..8ff9c074718 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -87,7 +87,7 @@ def run_settings_single(self, value, key):
         if not opts.same_type(value, opts.data_labels[key].default):
             return gr.update(visible=True), opts.dumpjson()
 
-        if not opts.set(key, value):
+        if value is None or not opts.set(key, value):
             return gr.update(value=getattr(opts, key)), opts.dumpjson()
 
         opts.save(shared.config_filename)

From 7e5fcdaf694343bd66d58af6644e47d5b8f8b879 Mon Sep 17 00:00:00 2001
From: dhwz
Date: Tue, 29 Aug 2023 18:49:42 +0200
Subject: [PATCH 0041/1174] don't print empty lines

---
 modules/launch_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 9aa0f071fd0..05488fe6382 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -228,7 +228,7 @@ def run_extension_installer(extension_dir):
         env = os.environ.copy()
         env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"
 
-        stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)
+        stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip()
         if stdout:
             print(stdout)
     except Exception as e:

From 549b475be9b5cf1dca0a7bad2b6a6381e50e2b37 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Tue, 29 Aug 2023 14:22:04 -0400
Subject: [PATCH 0042/1174] Add noisy latent to ExtraNoiseParams for callback

---
 modules/script_callbacks.py       | 7 +++++--
 modules/sd_samplers_kdiffusion.py | 2 +-
 modules/sd_samplers_timesteps.py  | 2 +-
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index fab23551a68..c99695eb3d9 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -29,12 +29,15 @@ def __init__(self, image, p, filename, pnginfo):
 
 
 class ExtraNoiseParams:
-    def __init__(self, noise, x):
+    def __init__(self, noise, x, xi):
         self.noise = noise
         """Random noise generated by the seed"""
 
         self.x = x
-        """Latent image representation of the image"""
+        """Latent representation of the image"""
+
+        self.xi = xi
+        """Noisy latent representation of the image"""
 
 
 class CFGDenoiserParams:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 72c352a6e6e..8a8c87e0d01 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -148,7 +148,7 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
 
         if opts.img2img_extra_noise > 0:
             p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
-            extra_noise_params = ExtraNoiseParams(noise, x)
+            extra_noise_params = ExtraNoiseParams(noise, x, xi)
             extra_noise_callback(extra_noise_params)
             noise = extra_noise_params.noise
             xi += noise * opts.img2img_extra_noise
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index 7a6cbd46d9c..b17a8f93c2b 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -107,7 +107,7 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
 
         if opts.img2img_extra_noise > 0:
             p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
-            extra_noise_params = ExtraNoiseParams(noise, x)
+            extra_noise_params = ExtraNoiseParams(noise, x, xi)
             extra_noise_callback(extra_noise_params)
             noise = extra_noise_params.noise
             xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod

From 9a4a1aac81df8c29d6dcce1abbf1b58fb7e4fc75 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 30 Aug 2023 08:05:18 +0300
Subject: [PATCH 0043/1174] get progressbar to display correctly in extensions tab

---
 javascript/extensions.js | 2 +-
 modules/ui_extensions.py | 8 ++++++--
 style.css                | 5 +++++
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/javascript/extensions.js b/javascript/extensions.js
index 1f7254c5dfe..312131b76eb 100644
--- a/javascript/extensions.js
+++ b/javascript/extensions.js
@@ -33,7 +33,7 @@ function extensions_check() {
 
     var id = randomId();
 
-    requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
+    requestProgress(id, gradioApp().getElementById('extensions_installed_html'), null, function() {
 
     });
 
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index e01382676a2..67a243c359b 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -557,8 +557,12 @@ def create_ui():
                         msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
                     html = f'{msg}'
 
-                info = gr.HTML(html)
-                extensions_table = gr.HTML('Loading...')
+                with gr.Row():
+                    info = gr.HTML(html)
+
+                with gr.Row(elem_classes="progress-container"):
+                    extensions_table = gr.HTML('Loading...', elem_id="extensions_installed_html")
+
                 ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
 
                 apply.click(
diff --git a/style.css b/style.css
index 92d3030e5d2..fb4e2f1f02e 100644
--- a/style.css
+++ b/style.css
@@ -517,6 +517,11 @@ table.popup-table .link{
     background: #b4c0cc;
     border-radius: 3px !important;
     top: -20px;
+    width: 100%;
+}
+
+.progress-container{
+    position: relative;
 }
 
 [id$=_results].mobile{

From edf3ad5aed9435e2ff3cc0f98895be6056f1f950 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 30 Aug 2023 08:22:06 +0300
Subject: [PATCH 0044/1174] go back to single path for filenames in extra networks metadata dialog

---
 modules/ui_extra_networks_user_metadata.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index 588f84c75e2..ae972fbb2fa 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -2,22 +2,13 @@
 import html
 import json
 import os.path
-from pathlib import Path
 
 import gradio as gr
 
-from modules import generation_parameters_copypaste, images, sysinfo, errors
+from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks
 from modules.paths_internal import models_path
 
 
-def exclude_root_path(root_path, path):
-    path_object = Path(path)
-    if path_object.is_relative_to(root_path):
-        path_object = path_object.relative_to(root_path)
-
-    return path_object.as_posix()
-
-
 class UserMetadataEditor:
 
     def __init__(self, ui, tabname, page):
@@ -99,6 +90,13 @@ def get_card_html(self, name):
 
         return preview
 
+    def relative_path(self, path):
+        for parent_path in self.page.allowed_directories_for_previews():
+            if ui_extra_networks.path_is_parent(parent_path, path):
+                return os.path.relpath(path, parent_path)
+
+        return os.path.basename(path)
+
     def get_metadata_table(self, name):
         item = self.page.items.get(name, {})
         try:
@@ -107,8 +105,7 @@ def get_metadata_table(self, name):
         stats = os.stat(filename)
 
         params = [
-            ('Filename: ', os.path.basename(filename)),
-            ('Path: ', exclude_root_path(models_path, filename)),
+            ('Filename: ', self.relative_path(filename)),
             ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
             ('Hash: ', shorthash),
             ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),

From f874b1bcad05d7ea4c3cc28df82904ac7c12e64f Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 30 Aug 2023 08:54:31 +0300
Subject: [PATCH 0045/1174] keep order in list of checkpoints when loading model that doesn't have a checksum

---
 modules/sd_models.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 547e93c442e..930d0bee5c8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -27,6 +27,24 @@
 checkpoints_loaded = collections.OrderedDict()
 
 
+def replace_key(d, key, new_key, value):
+    keys = list(d.keys())
+
+    d[new_key] = value
+
+    if key not in keys:
+        return d
+
+    index = keys.index(key)
+    keys[index] = new_key
+
+    new_d = {k: d[k] for k in keys}
+
+    d.clear()
+    d.update(new_d)
+    return d
+
+
 class CheckpointInfo:
     def __init__(self, filename):
         self.filename = filename
@@ -91,9 +109,11 @@ def calculate_shorthash(self):
             if self.shorthash not in self.ids:
                 self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
 
-            checkpoints_list.pop(self.title, None)
+            old_title = self.title
             self.title = f'{self.name} [{self.shorthash}]'
             self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
+
+            replace_key(checkpoints_list, old_title, self.title, self)
 
             self.register()
 
         return self.shorthash

From 28b084ca25387340ba07a5ffed8403d8d289cb70 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 30 Aug 2023 15:28:46 +0900
Subject: [PATCH 0046/1174] extension time format in system time zone

---
 modules/ui_extensions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 83557d7a003..fa831f5703e 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -177,7 +177,7 @@ def extension_table():
                 {remote}
                 {ext.branch}
                 {version_link}
-                {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(ext.commit_date))}
+                {datetime.fromtimestamp(ext.commit_date) if ext.commit_date else ""}
                 {ext_status}
 """

From 67cd4ec0aabb69d1133dfb18543ab080be855323 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 30 Aug 2023 15:37:13 +0900
Subject: [PATCH 0047/1174] lint

---
 modules/ui_extra_networks_user_metadata.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index ae972fbb2fa..bfec140cc73 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -6,7 +6,6 @@
 import gradio as gr
 
 from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks
-from modules.paths_internal import models_path

From c985d23c52b541e5a5d0dcf2c3f3a0629cee23f9 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 30 Aug 2023 16:18:31 +0900
Subject: [PATCH 0048/1174] extension update time, convert to system time zone

---
 modules/ui_extensions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index fa831f5703e..2e8c1d6d21d 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -2,7 +2,7 @@
 import os
 import threading
 import time
-from datetime import datetime
+from datetime import datetime, timezone
 
 import git
 
@@ -442,7 +442,7 @@ def search_extensions(filter_text, hide_tags, sort_column):
 
 def get_date(info: dict, key):
     try:
-        return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
+        return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc).astimezone().strftime("%Y-%m-%d")
     except (ValueError, TypeError):
         return ''

From ae0b2cc1964486ba847290ad752d9a284b6d63ba Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 30 Aug 2023 18:22:50 +0300
Subject: [PATCH 0049/1174] add an option to choose how to combine hires fix and refiner

---
 modules/processing.py         | 18 ++++++------------
 modules/sd_samplers_common.py | 13 +++++++++++--
 modules/shared_options.py     |  1 +
 3 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index 0138e5ac9be..f696e9251ff 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1148,18 +1148,12 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
         else:
             decoded_samples = None
 
-        current = shared.sd_model.sd_checkpoint_info
-        try:
-            if self.hr_checkpoint_info is not None:
-                self.sampler = None
-                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
-                devices.torch_gc()
-
-            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
-        finally:
-            self.sampler = None
-            sd_models.reload_model_weights(info=current)
-            devices.torch_gc()
+        with sd_models.SkipWritingToConfig():
+            sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+
+        devices.torch_gc()
+
+        return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
 
     def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
         if shared.state.interrupted:
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 60fa161cc7e..6c935a38f77 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -164,8 +164,17 @@ def apply_refiner(cfg_denoiser):
     if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
         return False
 
-    if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass:
-        return False
+    if getattr(cfg_denoiser.p, "enable_hr", False):
+        is_second_pass = cfg_denoiser.p.is_hr_pass
+
+        if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
+            return False
+
+        if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
+            return False
+
+        if opts.hires_fix_refiner_pass != "second pass":
+            cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
 
     cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
     cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 78652ea27b7..00b273faa54 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -146,6 +146,7 @@
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}, infotext="RNG").info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
     "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
+    "hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
 }))
 
 options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
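A condensed restatement of the gating that patch 0049 adds to apply_refiner, pulled out as a pure function so the three option values are easy to read (the helper name is mine, not part of the webui API):

    def refiner_enabled_for(hires_fix_refiner_pass: str, is_hr_pass: bool) -> bool:
        # "first pass"  -> refiner runs only on the base (first) pass
        # "second pass" -> refiner runs only on the hires (second) pass
        # "both passes" -> refiner runs on both
        if hires_fix_refiner_pass == "first pass" and is_hr_pass:
            return False
        if hires_fix_refiner_pass == "second pass" and not is_hr_pass:
            return False
        return True

    assert refiner_enabled_for("second pass", is_hr_pass=True)
    assert not refiner_enabled_for("first pass", is_hr_pass=True)
    assert refiner_enabled_for("both passes", is_hr_pass=False)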
From 6adf2b71c2c89f84d4aee1e230276dcd1a3fab62 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 30 Aug 2023 19:08:04 +0300
Subject: [PATCH 0050/1174] fix inpainting models in txt2img creating black pictures

---
 modules/processing.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index f696e9251ff..e08b6305f99 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -91,8 +91,8 @@ def create_binary_mask(image):
 def txt2img_image_conditioning(sd_model, x, width, height):
     if sd_model.model.conditioning_key in {'hybrid', 'concat'}:  # Inpainting models
 
-        # The "masked-image" in this case will just be all zeros since the entire image is masked.
-        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+        # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+        image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
         image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
 
         # Add the fake full 1s mask to the first dimension.

From 541a3db05ba7241b466f7370533e2bef24dbe9de Mon Sep 17 00:00:00 2001
From: ljleb
Date: Wed, 30 Aug 2023 21:38:21 -0400
Subject: [PATCH 0051/1174] fix generation params regex

---
 modules/generation_parameters_copypaste.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 2ca1605547c..d39f2ebac36 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -9,7 +9,7 @@
 from modules import shared, ui_tempdir, script_callbacks, processing
 from PIL import Image
 
-re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
+re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
 re_param = re.compile(re_param_code)
 re_imagesize = re.compile(r"^(\d+)x(\d+)$")
 re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")

From 41196ccbf7552274cf111de24a43ebfa836175a6 Mon Sep 17 00:00:00 2001
From: zixaphir
Date: Wed, 30 Aug 2023 20:20:19 -0700
Subject: [PATCH 0052/1174] account for customizable extra network separators in remove code

previous behavior only searched for leading spaces
---
 javascript/extraNetworks.js | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 493f31af28a..eb2b9ebd709 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -141,9 +141,12 @@ function setupExtraNetworks() {
 onUiLoaded(setupExtraNetworks);
 
 var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/;
-var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
+var re_extranet_str = '<([^:]+:[^:]+):[\\d.]+>';
 
 function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
+    function reEscape(s) {
+        return s.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
+    }
     var m = text.match(re_extranet);
     var replaced = false;
     var newTextareaText;
@@ -151,7 +154,9 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
         var extraTextAfterNet = m[2];
         var partToSearch = m[1];
         var foundAtPosition = -1;
-        newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
+        var escapedSeparator = reEscape(opts.extra_networks_add_text_separator);
+        var re = new RegExp(escapedSeparator + re_extranet_str, 'g');
+        newTextareaText = textarea.value.replaceAll(re, function(found, net, pos) {
            m = found.match(re_extranet);
            if (m[1] == partToSearch) {
                replaced = true;

From 76b1ad7daf35f8667e07ff9cff9ef42b828b1b83 Mon Sep 17 00:00:00 2001
From: missionfloyd
Date: Wed, 30 Aug 2023 23:07:18 -0600
Subject: [PATCH 0053/1174] Use default dropdown padding on mobile

---
 style.css | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/style.css b/style.css
index fb4e2f1f02e..58eb29c1598 100644
--- a/style.css
+++ b/style.css
@@ -83,8 +83,10 @@ div.compact{
     white-space: nowrap;
 }
 
-.gradio-dropdown ul.options li.item {
-    padding: 0.05em 0;
+@media (pointer:fine) {
+    .gradio-dropdown ul.options li.item {
+        padding: 0.05em 0;
+    }
 }
 
 .gradio-dropdown ul.options li.item.selected {
From 348c6022f330c6e64a6a0fb40fd2b3e65bf0ce6a Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Fri, 1 Sep 2023 00:55:56 +0900
Subject: [PATCH 0054/1174] Action to calculate all SD checkpoint hashes

---
 modules/ui_settings.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 8ff9c074718..c6fe3604a1a 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -5,6 +5,7 @@
 from modules.shared import opts
 from modules.ui_components import FormRow
 from modules.ui_gradio_extensions import reload_javascript
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 
 def get_value_for_setting(key):
@@ -175,6 +176,9 @@ def create_ui(self, loadsave, dummy_component):
                 with gr.Row():
                     unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
                     reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+                with gr.Row():
+                    calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoints', elem_id="calculate_all_checkpoint_hash")
+                    calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
 
             with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
                 gr.HTML(shared.html("licenses.html"), elem_id="licenses")
@@ -241,6 +245,21 @@ def check_file(x):
             outputs=[sysinfo_check_output],
         )
 
+        def calculate_all_checkpoint_hash_fn(max_thread):
+            checkpoints_list = sd_models.checkpoints_list.values()
+            with ThreadPoolExecutor(max_workers=max_thread) as executor:
+                futures = [executor.submit(checkpoint.calculate_shorthash) for checkpoint in checkpoints_list]
+                completed = 0
+                for _ in as_completed(futures):
+                    completed += 1
+                    print(f"{completed} / {len(checkpoints_list)} ")
+            print("Finished calculating hashes for all checkpoints")
+
+        calculate_all_checkpoint_hash.click(
+            fn=calculate_all_checkpoint_hash_fn,
+            inputs=[calculate_all_checkpoint_hash_threads],
+        )
+
         self.interface = settings_interface
 
     def add_quicksettings(self):

From 5681bf801664aa09fa02ab8b4e73f780d9563440 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Thu, 31 Aug 2023 14:57:16 -0400
Subject: [PATCH 0055/1174] More accurate check for enabling cuDNN benchmark on 16XX cards

---
 modules/devices.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/modules/devices.py b/modules/devices.py
index c01f06024b4..63c38eff1f1 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -60,7 +60,8 @@ def enable_tf32():
 
         # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
+        device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
+        if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
             torch.backends.cudnn.benchmark = True
 
     torch.backends.cuda.matmul.allow_tf32 = True
From bd9b3d15e8f631f9475d14a5fd07560c177dc2f3 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Fri, 1 Sep 2023 04:05:58 +0900
Subject: [PATCH 0056/1174] fix batch img2img output dir with script

---
 modules/img2img.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/modules/img2img.py b/modules/img2img.py
index c81c7ab9e8c..0c6d1af572c 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -114,11 +114,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
         else:
             p.override_settings.pop("sd_model_checkpoint", None)
 
+        if output_dir:
+            p.outpath_samples = output_dir
+            p.override_settings['save_to_dirs'] = False
+
         proc = modules.scripts.scripts_img2img.run(p, *args)
+
         if proc is None:
             if output_dir:
-                p.outpath_samples = output_dir
-                p.override_settings['save_to_dirs'] = False
                 if p.n_iter > 1 or p.batch_size > 1:
                     p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
                 else:

From 78c1a74660a2e25b3960beb42e3a6f8419c8b3c3 Mon Sep 17 00:00:00 2001
From: zixaphir
Date: Thu, 31 Aug 2023 14:18:35 -0700
Subject: [PATCH 0057/1174] Account for edge case where user deleted leading separator.

---
 javascript/extraNetworks.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index eb2b9ebd709..ca87beade13 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -154,7 +154,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
         var extraTextAfterNet = m[2];
         var partToSearch = m[1];
         var foundAtPosition = -1;
-        var escapedSeparator = reEscape(opts.extra_networks_add_text_separator);
+        var escapedSeparator = `(?:${reEscape(opts.extra_networks_add_text_separator)})?`;
         var re = new RegExp(escapedSeparator + re_extranet_str, 'g');
         newTextareaText = textarea.value.replaceAll(re, function(found, net, pos) {
             m = found.match(re_extranet);

From 737a013377dd698e620f39e405594f7688656af0 Mon Sep 17 00:00:00 2001
From: Beinsezii
Date: Thu, 31 Aug 2023 15:03:08 -0700
Subject: [PATCH 0058/1174] WEBUI.SH Navi 3 torch 2.1.0 rc instead of nightly

With the release candidates being out for both torch and vision, webui
should default to these over nightly for a more stable experience.
Stable release isn't expected until October 4th:
https://dev-discuss.pytorch.org/c/release-announcements/27
---
 webui.sh | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/webui.sh b/webui.sh
index 3d0f87eed74..29a6d31185a 100755
--- a/webui.sh
+++ b/webui.sh
@@ -141,9 +141,8 @@ case "$gpu_info" in
     *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
     ;;
     *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
-        export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6"
-        # Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5)
-        # so switch to nightly rocm5.6 without explicit versions this time
+        export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6"
+        # Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now
     ;;
     *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
         printf "\n%s\n" "${delimiter}"

From 317d00b2a6f81eb58e33487abf05a8b84ef01dd0 Mon Sep 17 00:00:00 2001
From: AnyISalIn
Date: Fri, 1 Sep 2023 21:45:11 +0800
Subject: [PATCH 0059/1174] fix: update shared.opts.data when add_option

Signed-off-by: AnyISalIn
---
 modules/options.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/options.py b/modules/options.py
index 758b1ce5f24..e75916d2be2 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -210,6 +210,7 @@ def dumpjson(self):
 
     def add_option(self, key, info):
         self.data_labels[key] = info
+        self.data[key] = info.default
 
     def reorder(self):
         """reorder settings so that all items related to section always go together"""

From bf0b08321688f65905168b6444d6d13b1a1d9d91 Mon Sep 17 00:00:00 2001
From: missionfloyd
Date: Fri, 1 Sep 2023 16:14:33 -0600
Subject: [PATCH 0060/1174] Add button to copy prompt to style editor

---
 modules/ui_prompt_styles.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
index 85eb3a6417e..46a26573b40 100644
--- a/modules/ui_prompt_styles.py
+++ b/modules/ui_prompt_styles.py
@@ -4,6 +4,7 @@
 
 styles_edit_symbol = '\U0001f58c\uFE0F'  # 🖌️
 styles_materialize_symbol = '\U0001f4cb'  # 📋
+styles_copy_symbol = '\U0001f4dd'  # 📝
 
 
 def select_style(name):
@@ -62,6 +63,7 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
                 self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
                 ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
                 self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selection dropdown in main UI to the prompt.")
+                self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
 
             with gr.Row():
                 self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
@@ -102,6 +104,13 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
             outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
             show_progress=False,
         ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+        self.copy.click(
+            fn=lambda p, n: (p, n),
+            inputs=[main_ui_prompt, main_ui_negative_prompt],
+            outputs=[self.prompt, self.neg_prompt],
+            show_progress=False,
+        )
 
         ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)

From d7e3ea68b3604aaec6607aad3272e999657e6331 Mon Sep 17 00:00:00 2001
From: missionfloyd
Date: Fri, 1 Sep 2023 16:24:35 -0600
Subject: [PATCH 0061/1174] Remove whitespace

---
 modules/ui_prompt_styles.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
index 46a26573b40..64d379ef65e 100644
--- a/modules/ui_prompt_styles.py
+++ b/modules/ui_prompt_styles.py
@@ -104,7 +104,7 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
             outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
             show_progress=False,
         ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
-
+
         self.copy.click(
             fn=lambda p, n: (p, n),
             inputs=[main_ui_prompt, main_ui_negative_prompt],
             outputs=[self.prompt, self.neg_prompt],
             show_progress=False,
         )

From 3e67017dfb767f18f599f13e62fff9355ea14160 Mon Sep 17 00:00:00 2001
From: missionfloyd
Date: Fri, 1 Sep 2023 17:01:08 -0600
Subject: [PATCH 0062/1174] Restore missing tooltips

---
 modules/processing_scripts/seed.py |  4 ++--
 modules/ui.py                      | 12 ++++++------
 modules/ui_extra_networks.py       |  2 +-
 scripts/postprocessing_upscale.py  |  2 +-
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index 6b6ff987d2d..dc9c2da5000 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -29,8 +29,8 @@ def ui(self, is_img2img):
             else:
                 self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
 
-            random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
-            reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
+            random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), tooltip="Set seed to -1, which will cause a new random number to be used every time")
+            reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), tooltip="Reuse seed from last generation, mostly useful if it was randomized")
 
             seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
 
diff --git a/modules/ui.py b/modules/ui.py
index 579bab9800c..89173053866 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -215,9 +215,9 @@ def __init__(self, is_img2img):
             )
 
         with gr.Row(elem_id=f"{id_part}_tools"):
-            self.paste = ToolButton(value=paste_symbol, elem_id="paste")
-            self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
-            self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
+            self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
+            self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt", tooltip="Clear prompt")
+            self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False, tooltip="Restore progress")
 
             self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
             self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
@@ -348,7 +348,7 @@ def create_ui():
                                 height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
 
                             with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
-                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
+                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height")
 
                             if opts.dimensions_and_batch_together:
                                 with gr.Column(elem_id="txt2img_column_batch"):
@@ -661,8 +661,8 @@ def copy_image(img):
                                 width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
                                 height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
                             with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
-                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
-                                detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
+                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height")
+                                detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img")
 
                             with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
                                 scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 063bd7b80e6..21eed6a1d6b 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -374,7 +374,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
 
         edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
         dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
-        button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
+        button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False, tooltip="Invert sort order")
         button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
         checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
 
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index edb70ac01ca..eb42a29e521 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -29,7 +29,7 @@ def ui(self):
                     upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
                     upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
                 with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
-                    upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
+                    upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
 
                 upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
 
 with FormRow():

From ba05e327896898eb73caec3ed710fe45d1e38732 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 2 Sep 2023 14:12:59 +0900
Subject: [PATCH 0063/1174] update cmd arg description

---
 modules/cmd_args.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index aab62286e24..a77c7e77dc3 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -112,8 +112,8 @@
 parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
 parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
 parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
-parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--add-stop-route', action='store_true', help='does not do anything')
 parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
 parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
 parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
-parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)

From 061a4a295dd53f089ac460bc2c1585491ef26f24 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 2 Sep 2023 18:11:08 +0900
Subject: [PATCH 0064/1174] Update bug_report.yml

---
 .github/ISSUE_TEMPLATE/bug_report.yml | 90 ++++++++++++++++++++++-----
 1 file changed, 74 insertions(+), 16 deletions(-)

diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -4,22 +4,45 @@ title: "[Bug]: "
 labels: ["bug-report"]
 
 body:
+  - type: markdown
+    attributes:
+      value: |
+        > The title of the bug report should be short and descriptive
+        > Use relevant keywords for searchability
+        > Don't leave it blank, but also don't put an entire error log in it
   - type: checkboxes
     attributes:
-      label: Is there an existing issue for this?
-      description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
+      label: Checklist
+      description: |
+        Please perform basic debugging to see if extensions or configuration is the cause of the issue.
+        Basic debug procedure
+         1. Disable all third-party extensions - check if an extension is the cause
+         2. Update extensions and webui - sometimes things just need to be updated
+         3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration
+         4. Delete venv with third-party extensions disabled - sometimes extensions might cause the wrong libraries to be installed
+         5. Try a fresh installation of webui in a different directory - see if a clean installation solves the issue
+        Before making an issue report, please check that the issue hasn't been reported recently
       options:
-        - label: I have searched the existing issues and checked the recent builds/commits
-          required: true
+        - label: The issue exists after disabling all extensions
+        - label: The issue exists on a clean installation of webui
+        - label: The issue is caused by an extension, but I believe it is caused by a bug in webui
+        - label: The issue exists in the current version of webui
+        - label: The issue hasn't been reported recently
+        - label: The issue has been reported before but hasn't been fixed yet
   - type: markdown
     attributes:
       value: |
-        *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
+        > Please fill this form with as much information as possible; don't forget to upload the Sysinfo, fill in "What browsers", and provide screenshots if possible
   - type: textarea
     id: what-did
     attributes:
       label: What happened?
       description: Tell us what happened in a very clear and simple way
+      placeholder: |
+        I tried to use txt2img with XYZ grid with Sampler DPM++ SDE, DPM++ 2M SDE
+        it should generate a grid of 2 images but I only got 1
+
+        add a screenshot or screen recording if necessary
     validations:
       required: true
   - type: textarea
     id: steps
     attributes:
       label: Steps to reproduce the problem
       description: Please provide us with precise step by step instructions on how to reproduce the bug
-      value: |
-        1. Go to ....
-        2. Press ....
-        3. ...
+      placeholder: |
+        1. Go to txt2img tab, select XYZ grid
+        2. Set axis type Sampler and select DPM++ 2M SDE, DPM++ 3M SDE
+        3. Set Sampling steps to 1 and click the Generate button
     validations:
       required: true
   - type: textarea
     id: what-should
     attributes:
       label: What should have happened?
       description: Tell us what you think the normal behavior should be
-    validations:
-      required: true
-  - type: textarea
-    id: sysinfo
-    attributes:
-      label: Sysinfo
-      description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
+      placeholder: |
+        It should generate a grid of 2 images
+        this was working in webui version 1.x.x
     validations:
       required: true
   - type: dropdown
     id: browsers
     attributes:
       label: What browsers do you use to access the UI ?
       multiple: true
       options:
         - Mozilla Firefox
         - Google Chrome
         - Brave
         - Apple Safari
         - Microsoft Edge
+        - Android
+        - iOS
         - Other
+  - type: textarea
+    id: sysinfo
+    attributes:
+      label: Sysinfo
+      description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
+      placeholder: |
+        Upload the Sysinfo as an attached file
+        Don't paste it in as text
+    validations:
+      required: true
   - type: textarea
     id: logs
     attributes:
       label: Console logs
       description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
       render: Shell
+      placeholder: |
+        generating image for xyz plot: UnboundLocalError
+        Traceback (most recent call last):
+          File "B:\GitHub\stable-diffusion-webui\scripts\xyz_grid.py", line 698, in cell
+            res = process_images(pc)
+          File "B:\GitHub\stable-diffusion-webui\modules\processing.py", line 732, in process_images
+            res = process_images_inner(p)
+          File "B:\GitHub\stable-diffusion-webui\modules\processing.py", line 867, in process_images_inner
+            samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+          File "B:\GitHub\stable-diffusion-webui\modules\processing.py", line 1140, in sample
+            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+          File "B:\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 235, in sample
+            samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+          File "B:\GitHub\stable-diffusion-webui\modules\sd_samplers_common.py", line 261, in launch_sampling
+            return func()
+          File "B:\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 235, in <lambda>
+            samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+          File "B:\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
+            return func(*args, **kwargs)
+          File "B:\GitHub\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\sampling.py", line 651, in sample_dpmpp_2m_sde
+            h_last = h
+        UnboundLocalError: local variable 'h' referenced before assignment
     validations:
       required: true
   - type: textarea
     id: additional-information
     attributes:
       label: Additional information
       description: Please provide us with any relevant additional info or context.
+      placeholder: |
+        Examples
+        I have updated the GPU driver recently
+        I suspect the issue is caused by XXXXX
+        I am using a VPN
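Looking back at patch 0059 above: the reason add_option must also write to opts.data is that reads and serialization consult the values dict. A toy model of that lookup chain (deliberately simplified; the real Options class in modules/options.py has more fallbacks, and the class names here mirror it only loosely):

    class OptionInfo:
        def __init__(self, default, label=""):
            self.default = default
            self.label = label

    class Options:
        def __init__(self):
            self.data = {}          # current values, saved to config.json
            self.data_labels = {}   # option metadata, including defaults

        def add_option(self, key, info):
            self.data_labels[key] = info
            self.data[key] = info.default  # the one-line fix from patch 0059

        def __getattr__(self, item):
            # attribute reads consult the values dict, so a freshly
            # registered option must be present there as well
            try:
                return object.__getattribute__(self, "data")[item]
            except KeyError:
                raise AttributeError(item)

    opts = Options()
    opts.add_option("sgm_noise_multiplier", OptionInfo(False, "SGM noise multiplier"))
    assert opts.sgm_noise_multiplier is False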
+ placeholder: | + Examples + I have updated the GPU driver recently + I suspect the issue is caused by XXXXX + I am using a VPN From a51721cb09aa8dc68beb08cf8f0a2602b41d052c Mon Sep 17 00:00:00 2001 From: uservar <63248296+uservar@users.noreply.github.com> Date: Sat, 2 Sep 2023 11:35:30 +0000 Subject: [PATCH 0065/1174] Fix bug with sigma min/max overrides. --- modules/shared_options.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 00b273faa54..73588a22188 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -305,8 +305,8 @@ 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'), 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), - 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), - 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), + 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), + 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), 'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"), From aab385d01b4311726127397552d791f4d71b7147 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 3 Sep 2023 11:56:02 +0900 Subject: [PATCH 0066/1174] thread safe extra network list_items --- extensions-builtin/Lora/ui_extra_networks_lora.py | 10 +++++----- modules/ui_extra_networks.py | 2 ++ modules/ui_extra_networks_checkpoints.py | 6 +++--- modules/ui_extra_networks_hypernets.py | 5 +++-- modules/ui_extra_networks_textual_inversion.py | 5 +++-- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 55409a7829d..e9f300621ac 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -66,11 +66,11 @@ def create_item(self, name, index=None, 
enable_filter=True): return item def list_items(self): - for index, name in enumerate(networks.available_networks): - item = self.create_item(name, index) - - if item is not None: - yield item + with self.thread_lock: + for index, name in enumerate(networks.available_networks): + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat] diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 063bd7b80e6..564bab7fe02 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,6 +1,7 @@ import os.path import urllib.parse from pathlib import Path +from threading import Lock from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo @@ -94,6 +95,7 @@ def __init__(self, title): self.allow_negative_prompt = False self.metadata = {} self.items = {} + self.thread_lock = Lock() def refresh(self): pass diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index ca6c26076f9..2753214fa10 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -30,9 +30,9 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - names = list(sd_models.checkpoints_list) - for index, name in enumerate(names): - yield self.create_item(name, index) + with self.thread_lock: + for index, name in enumerate(sd_models.checkpoints_list): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 4cedf085196..411b4f11145 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -31,8 +31,9 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - for index, name in enumerate(shared.hypernetworks): - yield self.create_item(name, index) + with self.thread_lock: + for index, name in enumerate(shared.hypernetworks): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [shared.cmd_opts.hypernetwork_dir] diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 55ef0ea7b54..d25b45d6133 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -29,8 +29,9 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): - yield self.create_item(name, index) + with self.thread_lock: + for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) From f593cbfec417a3ea40589fe64e7f8806c9a81e5a Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Sun, 3 Sep 2023 21:07:36 +0900 Subject: [PATCH 0067/1174] fallback if exif data was invalid --- modules/images.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index eb644733898..8512a46ec76 100644 --- a/modules/images.py +++ b/modules/images.py @@ 
-718,7 +718,12 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: geninfo = items.pop('parameters', None) if "exif" in items: - exif = piexif.load(items["exif"]) + exif_data = items["exif"] + try: + exif = piexif.load(exif_data) + except OSError: + # memory / exif was not valid so piexif tried to read from a file + exif = None exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'') try: exif_comment = piexif.helper.UserComment.load(exif_comment) From 8f3b02f09535f55d3673aa9ea589396b8614f799 Mon Sep 17 00:00:00 2001 From: JaredTherriault Date: Sun, 3 Sep 2023 13:31:42 -0700 Subject: [PATCH 0068/1174] Revert "Offloading custom work" This reverts commit f3d1631aab82d559294126a9230c979ef4c4e1d6. This work has been offloaded now into an extension called Prompt Control. --- modules/custom_statics.py | 29 ---------------------- modules/generation_parameters_copypaste.py | 8 ------ webui-user.bat | 4 +-- 3 files changed, 2 insertions(+), 39 deletions(-) delete mode 100644 modules/custom_statics.py diff --git a/modules/custom_statics.py b/modules/custom_statics.py deleted file mode 100644 index 207bd5fbb90..00000000000 --- a/modules/custom_statics.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -import gc -import re - -import modules.paths as paths - -class CustomStatics: - - @staticmethod - # loads a file with strings structured as below, on each line with a : between the search and replace strings, into a list - # search0:replace0 - # search string:replace string - # - # Then replaces all occurrences of the list's search strings with the list's replace strings in one go - def mass_replace_strings(input_string): - with open(os.path.join(paths.data_path, "custom_statics/Replacements.txt"), "r", encoding="utf8") as file: - replacements = file.readlines() - - replacement_dict = {} - for line in replacements: - search, replace = line.strip().split(":") - replacement_dict[search] = replace - - def replace(match_text): - return replacement_dict[match_text.group(0)] - - return re.sub('|'.join(r'\b%s\b' % re.escape(s) for s in replacement_dict.keys()), replace, str(input_string)) - - return str(geninfo) \ No newline at end of file diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index bd7b0018350..a3448be9db8 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -370,14 +370,6 @@ def paste_func(prompt): prompt = file.read() params = parse_generation_parameters(prompt) - - # This sanitizes unsavory prompt words when copying from another image - # for my own sanity. 
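The deleted helper above illustrates a reusable pattern: compile one word-bounded regex alternation from a search-to-replace table, then let re.sub consult the table once per match, so every substitution happens in a single pass over the text. A standalone sketch of the same technique, with the file loading omitted:

    import re

    def mass_replace(text: str, replacements: dict[str, str]) -> str:
        if not replacements:
            return text
        # one alternation of escaped, word-bounded search terms
        pattern = re.compile('|'.join(r'\b%s\b' % re.escape(s) for s in replacements))
        # the replacement callback looks each matched term up in the table
        return pattern.sub(lambda m: replacements[m.group(0)], text)

    print(mass_replace("foo meets bar", {"foo": "baz", "bar": "qux"}))  # baz meets qux

Note that the deleted file also carried an unreachable trailing return str(geninfo) referencing an undefined name, so offloading this logic into a dedicated extension was a reasonable cleanup.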
This is not intended to be contributed to the main repo, - # it's just so I don't have to see anything I'm not interested in when batch - # reproducing images from civit.ai or elsewhere when working on loras - # todo: make this work with the callback instead of forcing it here, this can be an extension when I feel like putting it together :D - from modules import custom_statics - params = custom_statics.CustomStatics.mass_replace_strings(params) script_callbacks.infotext_pasted_callback(prompt, params) res = [] diff --git a/webui-user.bat b/webui-user.bat index 1ba2116d01c..e5a257bef06 100644 --- a/webui-user.bat +++ b/webui-user.bat @@ -1,8 +1,8 @@ @echo off -set TEMP=G:\SD-temp + set PYTHON= set GIT= set VENV_DIR= -set COMMANDLINE_ARGS= --xformers +set COMMANDLINE_ARGS= call webui.bat From 022639a145751d61db1c144e5e657aa6481e2bc0 Mon Sep 17 00:00:00 2001 From: JaredTherriault Date: Mon, 4 Sep 2023 17:37:48 -0700 Subject: [PATCH 0069/1174] Load comments from gif images to gather geninfo from gif outputs --- modules/images.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/images.py b/modules/images.py index eb644733898..8c6e862feaf 100644 --- a/modules/images.py +++ b/modules/images.py @@ -728,6 +728,8 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: if exif_comment: items['exif comment'] = exif_comment geninfo = exif_comment + elif "comment" in items: # for gif + geninfo = items["comment"].decode('utf8', errors="ignore") for field in IGNORED_INFO_KEYS: items.pop(field, None) From 0c1c9e74cda6637ab1305b4c294b7719eb141927 Mon Sep 17 00:00:00 2001 From: liubo0902 <38622806+liubo0902@users.noreply.github.com> Date: Tue, 5 Sep 2023 15:06:47 +0800 Subject: [PATCH 0070/1174] Update localization.py --- modules/localization.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/modules/localization.py b/modules/localization.py index c132028856f..3392b0557af 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,7 +1,7 @@ import json import os -from modules import errors, scripts +from modules import errors localizations = {} @@ -14,21 +14,27 @@ def list_localizations(dirname): if ext.lower() != ".json": continue - localizations[fn] = os.path.join(dirname, file) + fn = fn.replace(" ", "").replace("(", "_").replace(")","") + localizations[fn] = [os.path.join(dirname, file)] + from modules import scripts for file in scripts.list_scripts("localizations", ".json"): fn, ext = os.path.splitext(file.filename) - localizations[fn] = file.path + fn = fn.replace(" ", "").replace("(", "_").replace(")","") + if fn not in localizations: + localizations[fn] = [] + localizations[fn].append(file.path) def localization_js(current_localization_name: str) -> str: - fn = localizations.get(current_localization_name, None) + fns = localizations.get(current_localization_name, None) data = {} - if fn is not None: - try: - with open(fn, "r", encoding="utf8") as file: - data = json.load(file) - except Exception: - errors.report(f"Error loading localization from {fn}", exc_info=True) + if fns is not None: + for fn in fns: + try: + with open(fn, "r", encoding="utf8") as file: + data.update(json.load(file)) + except Exception: + errors.report(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" From ff7027ffc075ae44ddaa56014f900d392cf53ca8 Mon Sep 17 00:00:00 2001 From: liubo0902 <38622806+liubo0902@users.noreply.github.com> Date: Tue, 5 Sep 2023 15:08:59 +0800 Subject: 
[PATCH 0071/1174] Update localization.py --- modules/localization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modules/localization.py b/modules/localization.py index 3392b0557af..262d49ee1c8 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,7 +1,7 @@ import json import os -from modules import errors +from modules import errors, scripts localizations = {} @@ -17,7 +17,6 @@ def list_localizations(dirname): fn = fn.replace(" ", "").replace("(", "_").replace(")","") localizations[fn] = [os.path.join(dirname, file)] - from modules import scripts for file in scripts.list_scripts("localizations", ".json"): fn, ext = os.path.splitext(file.filename) fn = fn.replace(" ", "").replace("(", "_").replace(")","") From de5bb4ca88df44362c9263de7334b30156540e21 Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Tue, 5 Sep 2023 22:35:17 +0900 Subject: [PATCH 0072/1174] Fix #13080 - Hypernetwork/TI preview generation Fixes sampler name reference Same patch will be done for TI. --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 70f1cbd26b6..65b63f2f81f 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -468,7 +468,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(id_task, hypernetwork_name:str, learn_rate:float, batch_size:int, gradient_step:int, data_root:str, log_directory:str, training_width:int, training_height:int, varsize:bool, steps:int, clip_grad_mode:str, clip_grad_value:float, shuffle_tags:bool, tag_drop_out:bool, latent_sampling_method:str, use_weight:bool, create_image_every:int, save_hypernetwork_every:int, template_filename:str, preview_from_txt2img:bool, preview_prompt:str, preview_negative_prompt:str, preview_steps:int, preview_sampler_name:str, preview_cfg_scale:float, preview_seed:int, preview_width:int, preview_height:int): from modules import images, processing save_hypernetwork_every = save_hypernetwork_every or 0 @@ -698,7 +698,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()] p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width From 47033afa5c08e72b622348b0bcfd71fd1a66e2cb Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Tue, 5 Sep 2023 22:38:02 +0900 Subject: [PATCH 0073/1174] Fix preview for textual inversion training --- modules/textual_inversion/textual_inversion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index aa79dc09843..401a0a2ab02 100644 --- 
a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -386,7 +386,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert log_directory, "Log directory is empty" -def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height): from modules import processing save_embedding_every = save_embedding_every or 0 @@ -590,7 +590,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()] p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width From 25189b29afb04ff0c203e7f666c8acaead09dcde Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 5 Sep 2023 22:13:36 -0400 Subject: [PATCH 0074/1174] Grammar fixes --- .github/ISSUE_TEMPLATE/bug_report.yml | 62 +++++++++++++-------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 70c2e16058b..a53db9f0ef7 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,5 +1,5 @@ name: Bug Report -description: You think somethings is broken in the UI +description: You think something is broken in the UI title: "[Bug]: " labels: ["bug-report"] @@ -7,9 +7,9 @@ body: - type: markdown attributes: value: | - > The title the bug report should be short and descriptive - > Use relevant keywords for searchability - > Don't leave it blank but also don't put the entire error log in it + > The title of the bug report should be short and descriptive. + > Use relevant keywords for searchability. + > Do not leave it blank, but also do not put an entire error log in it. - type: checkboxes attributes: label: Checklist @@ -17,32 +17,32 @@ body: Please perform basic debugging to see if extensions or configuration is the cause of the issue. Basic debug procedure  1. Disable all third-party extensions - check if extension is the cause -  2. Update extensions and webui - sometimes thing just need to be updated -  3. Backup and remove your config.json and ui-config.json - check if the issue is caused bed configuration -  4. delete venv with third-party extensions disable - sometimes extensions might cause wrong libraries to be installed -  5. 
try a fresh installation webui in a different directory - see if a clean installation solves the issue - Before making a issue report please check that the issue hasn't been reported recently +  2. Update extensions and webui - sometimes things just need to be updated +  3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration +  4. Delete venv with third-party extensions disabled - sometimes extensions might cause wrong libraries to be installed +  5. Try a fresh installation webui in a different directory - see if a clean installation solves the issue + Before making a issue report please, check that the issue hasn't been reported recently. options: - - label: The issue exist after disabling all extensions - - label: The issue exist on a clean installation of webui - - label: The issue is caused by an extension but it is caused by a bug in webui - - label: The issue exist in current version of webui - - label: The issue haven't been reported before recently - - label: The issue has been reported before but hasn't been fixed yet + - label: The issue exists after disabling all extensions + - label: The issue exists on a clean installation of webui + - label: The issue is caused by an extension, but I believe it is caused by a bug in the webui + - label: The issue exists in the current version of the webui + - label: The issue has not been reported before recently + - label: The issue has been reported before but has not been fixed yet - type: markdown attributes: value: | - > Please fill this form with as much information as possible, don't forget to "Upload Sysinfo" and "What browsers" and provide screenshots if possible + > Please fill this form with as much information as possible. Don't forget to "Upload Sysinfo" and "What browsers" and provide screenshots if possible - type: textarea id: what-did attributes: label: What happened? description: Tell us what happened in a very clear and simple way placeholder: | - I tried to use txt2img with XYZ grid with Sampler DPM++ SDE,DPM++ 2M SDE - it should generate a grid of 2 images but I only got 1 + I tried to use txt2img with the XYZ grid script, with DPM++ SDE, DPM++ 2M SDE samplers. + It should generate a grid of 2 images but I only got 1. - add screenshot or screen recording if necessary + (add screenshot or screen recording if necessary) validations: required: true - type: textarea @@ -51,9 +51,9 @@ body: label: Steps to reproduce the problem description: Please provide us with precise step by step instructions on how to reproduce the bug placeholder: | - 1. Go to txt2img tab Select XYZ grid - 2. Set axis type Sampler and select DPM++ 2M SDE, DPM++ 3M SDE - 3. Set Sampling steps to 1 and click Generate button + 1. Go to txt2img tab, select XYZ grid script + 2. Set axis type to `Sampler`, and select DPM++ 2M SDE, DPM++ 3M SDE + 3. Set `Sampling steps` to 1, click generate button validations: required: true - type: textarea @@ -62,8 +62,8 @@ body: label: What should have happened? description: Tell us what you think the normal behavior should be placeholder: | - It should generate a grid of 2 images - this was working in webui version 1.x.x + It should generate a grid of 2 images. + This was working in webui version 1.x.x validations: required: true - type: dropdown @@ -86,15 +86,15 @@ body: label: Sysinfo description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. 
If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file. placeholder: | - Upload the Sysinfo as a attached file - Don't paste it in as text + Upload the Sysinfo as a attached file. + Do not paste it in as text. validations: required: true - type: textarea id: logs attributes: label: Console logs - description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service. + description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occured. If it's very long, provide a link to pastebin or similar service. render: Shell placeholder: | generating image for xyz plot: UnboundLocalError @@ -126,7 +126,7 @@ body: label: Additional information description: Please provide us with any relevant additional info or context. placeholder: | - Examples - I have updated the GPU driver recently - I suspect the issue is caused by XXXXX - I am using a VPN + Examples: + I have updated the GPU driver recently. + I suspect the issue is caused by XXXXX. + I am using a VPN. From 35d1c94549cf75e7e312372d90fee0acc2806426 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 6 Sep 2023 20:24:26 +0900 Subject: [PATCH 0075/1174] save_images_add_number_suffix --- modules/images.py | 10 +++++++++- modules/shared_options.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/modules/images.py b/modules/images.py index eb644733898..10dcd9ab854 100644 --- a/modules/images.py +++ b/modules/images.py @@ -661,7 +661,15 @@ def _atomically_save_image(image_to_save, filename_without_extension, extension) save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name) - os.replace(temp_file_path, filename_without_extension + extension) + full_file_name = filename_without_extension + extension + if shared.opts.save_images_add_number_suffix and os.path.exists(full_file_name): + count = 1 + while True: + full_file_name = f"{filename_without_extension}_{count}{extension}" + if not os.path.exists(full_file_name): + break + count += 1 + os.replace(temp_file_path, full_file_name) fullfn_without_extension, extension = os.path.splitext(params.filename) if hasattr(os, 'statvfs'): diff --git a/modules/shared_options.py b/modules/shared_options.py index 00b273faa54..2f4caa9dd53 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -26,7 +26,7 @@ "samples_format": OptionInfo('png', 'File format for images'), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), - + "save_images_add_number_suffix": OptionInfo(True, "Add number suffix when necessary", component_args=hide_dirs).info("prevent existing image from being override"), "grid_save": OptionInfo(True, "Always save all generated image grids"), "grid_format": OptionInfo('png', 'File format for grids'), "grid_extended_filename": OptionInfo(False, "Add extended info 
(seed, prompt) to filename when saving grid"), From 657404b75b2f214a97281afbec1adcb4313d24eb Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 6 Sep 2023 20:33:43 +0900 Subject: [PATCH 0076/1174] use original filename batch img2img with scripts --- modules/img2img.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 0c6d1af572c..c1cae22f812 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -117,15 +117,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if output_dir: p.outpath_samples = output_dir p.override_settings['save_to_dirs'] = False + if p.n_iter > 1 or p.batch_size > 1: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]' + else: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}' proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: - if output_dir: - if p.n_iter > 1 or p.batch_size > 1: - p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]' - else: - p.override_settings['samples_filename_pattern'] = f'{image_path.stem}' process_images(p) From 340fce2113b6d68f06f5bb8c897be998f03b4c8c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 7 Sep 2023 10:01:16 +0900 Subject: [PATCH 0077/1174] enable console prompts in settings --- modules/img2img.py | 2 +- modules/shared_options.py | 1 + modules/txt2img.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index c81c7ab9e8c..cbd80baccf7 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -199,7 +199,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s p.user = request.username - if shared.cmd_opts.enable_console_prompts: + if shared.opts.enable_console_prompts or shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) if mask: diff --git a/modules/shared_options.py b/modules/shared_options.py index 00b273faa54..44fb1670544 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -100,6 +100,7 @@ options_templates.update(options_section(('system', "System"), { "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}), + "enable_console_prompts": OptionInfo(False, "Print prompts to console when generating with txt2img and img2img."), "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(), "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), diff --git a/modules/txt2img.py b/modules/txt2img.py index 1ee592ad944..379ef85967b 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -45,7 +45,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step p.user = request.username - if cmd_opts.enable_console_prompts: + if shared.opts.enable_console_prompts or cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) with closing(p): From 45881703c5b1c0499406a76fa49ec7bd408a4898 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 7 Sep 
2023 12:11:36 +0900 Subject: [PATCH 0078/1174] consolidated allowed preview formats --- modules/ui_extra_networks.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 063bd7b80e6..2e816254d0b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -15,6 +15,11 @@ extra_pages = [] allowed_dirs = set() +allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] +if shared.opts.samples_format not in allowed_preview_extensions: + allowed_preview_extensions.append(shared.opts.samples_format) +allowed_preview_extensions_dot = ['.' + extension for extension in allowed_preview_extensions] + def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" @@ -34,7 +39,7 @@ def fetch_file(filename: str = ""): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") ext = os.path.splitext(filename)[1].lower() - if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"): + if ext not in allowed_preview_extensions_dot: raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.") # would profit from returning 304 @@ -273,11 +278,7 @@ def find_preview(self, path): Find a preview PNG for a given path (without extension) and call link_preview on it. """ - preview_extensions = ["png", "jpg", "jpeg", "webp"] - if shared.opts.samples_format not in preview_extensions: - preview_extensions.append(shared.opts.samples_format) - - potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], []) + potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions], []) for file in potential_files: if os.path.isfile(file): From c3d51fc696bbe5f9ea2de63933234c90c55afbbd Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 7 Sep 2023 19:35:55 +0900 Subject: [PATCH 0079/1174] Update bug_report.yml --- .github/ISSUE_TEMPLATE/bug_report.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index a53db9f0ef7..a423f052d16 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -86,8 +86,9 @@ body: label: Sysinfo description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file. placeholder: | - Upload the Sysinfo as a attached file. - Do not paste it in as text. + 1. Go to WebUI Settings -> Sysinfo -> Download system info. + If WebUI fails to launch, use --dump-sysinfo commandline argument to generate the file + 2. Upload the Sysinfo as a attached file, Do NOT paste it in as plain text. validations: required: true - type: textarea From f11eec81e31bfc9195bbacda13b2a3ce7b98fd92 Mon Sep 17 00:00:00 2001 From: ibrainventures Date: Thu, 7 Sep 2023 23:19:52 +0200 Subject: [PATCH 0080/1174] (feat) Include Program Version in info response. Update processing.py This would help to organize / memorize the program version for the creation process. 
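A minimal sketch of what this patch amounts to: stamp each result object with the program version at creation time so the JSON handed to the UI and API is self-describing. Processed below is a trimmed stand-in for the webui's result class, and PROGRAM_VERSION substitutes for the real program_version() call:

    import json

    PROGRAM_VERSION = "1.6.0"  # assumption: stand-in for program_version()

    class Processed:
        """Trimmed stand-in for the result object returned by processing."""
        def __init__(self, seed: int, info: str):
            self.seed = seed
            self.info = info
            self.version = PROGRAM_VERSION  # recorded once, when the result is created

        def js(self) -> str:
            # serialized for the frontend; now includes the generating version
            return json.dumps({"seed": self.seed, "info": self.info, "version": self.version})

    print(Processed(42, "test prompt").js())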
(as it is also unformated included inside the infotext). --- modules/processing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index e124e7f0dd2..0c19142868e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -533,6 +533,7 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", self.all_seeds = all_seeds or p.all_seeds or [self.seed] self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed] self.infotexts = infotexts or [info] + self.version = program_version() def js(self): obj = { @@ -567,6 +568,7 @@ def js(self): "job_timestamp": self.job_timestamp, "clip_skip": self.clip_skip, "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning, + "version": self.version, } return json.dumps(obj) From e4726cccf960257e1b456db84a59f28cea019c8f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 8 Sep 2023 09:46:34 +0900 Subject: [PATCH 0081/1174] parsing string to path --- modules/sd_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 930d0bee5c8..9b0923de1f7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -49,11 +49,12 @@ class CheckpointInfo: def __init__(self, filename): self.filename = filename abspath = os.path.abspath(filename) + abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" - if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): - name = abspath.replace(shared.cmd_opts.ckpt_dir, '') + if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir): + name = abspath.replace(abs_ckpt_dir, '') elif abspath.startswith(model_path): name = abspath.replace(model_path, '') else: From 63485b2c55d2e5d1d5fc64d3964120a7305a9aee Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 8 Sep 2023 10:00:27 +0900 Subject: [PATCH 0082/1174] option use short name for checkpoint dropdown --- modules/shared_items.py | 4 ++-- modules/shared_options.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/shared_items.py b/modules/shared_items.py index 84d69c8df43..b1459f8c4e4 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -44,9 +44,9 @@ def refresh_unet_list(): modules.sd_unet.list_unets() -def list_checkpoint_tiles(): +def list_checkpoint_tiles(use_short=False): import modules.sd_models - return modules.sd_models.checkpoint_tiles() + return modules.sd_models.checkpoint_tiles(use_short) def refresh_checkpoints(): diff --git a/modules/shared_options.py b/modules/shared_options.py index 00b273faa54..7f71c51745a 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -133,7 +133,7 @@ })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'), "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", 
gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), @@ -261,6 +261,7 @@ "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), + "sd_checkpoint_dropdown_use_short": OptionInfo(False, "Use short name for Stable Diffusion checkpoint dropdown"), "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), From 259768f27fc4da61000610bc81a16f0152b36550 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 9 Sep 2023 08:38:42 +0300 Subject: [PATCH 0083/1174] fix the bug in script-info API --- modules/scripts.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/scripts.py b/modules/scripts.py index e8518ad0fba..f1f17a5f739 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -491,11 +491,15 @@ def create_script_ui(self, script): arg_info = api_models.ScriptArg(label=control.label or "") - for field in ("value", "minimum", "maximum", "step", "choices"): + for field in ("value", "minimum", "maximum", "step"): v = getattr(control, field, None) if v is not None: setattr(arg_info, field, v) + choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string + if choices is not None: + setattr(arg_info, 'choices', [x[0] if isinstance(x, tuple) else x for x in choices]) + api_args.append(arg_info) script.api_info = api_models.ScriptInfo( From 7b44b85730d392733a285fe7e5c9e077f7bbccd3 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 9 Sep 2023 02:01:12 -0400 Subject: [PATCH 0084/1174] refact --- modules/ui.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index f4028475661..0e78b6e1add 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,4 +1,5 @@ import datetime +import functools import mimetypes import os import sys @@ -151,11 +152,14 @@ def connect_clear_prompt(button): ) -def update_token_counter(text, steps): +def update_token_counter(text, steps, is_positive=True): try: text, _ = extra_networks.parse_prompt(text) - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + if is_positive: + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + else: + prompt_flat_list = [text] prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) except Exception: @@ -533,7 +537,7 @@ def create_ui(): ] toprow.token_button.click(fn=wrap_queued_call(update_token_counter), 
inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) - toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) + toprow.negative_token_button.click(fn=wrap_queued_call(functools.partial(update_token_counter, is_positive=False)), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) From 3ca4655a18eb80cca5f806412f2cb2d56cc536e5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 9 Sep 2023 09:08:31 +0300 Subject: [PATCH 0085/1174] update for #12926 --- modules/images.py | 16 +++++++--------- modules/shared_options.py | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/modules/images.py b/modules/images.py index 10dcd9ab854..5cf3c825d7e 100644 --- a/modules/images.py +++ b/modules/images.py @@ -661,15 +661,13 @@ def _atomically_save_image(image_to_save, filename_without_extension, extension) save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name) - full_file_name = filename_without_extension + extension - if shared.opts.save_images_add_number_suffix and os.path.exists(full_file_name): - count = 1 - while True: - full_file_name = f"{filename_without_extension}_{count}{extension}" - if not os.path.exists(full_file_name): - break - count += 1 - os.replace(temp_file_path, full_file_name) + filename = filename_without_extension + extension + if shared.opts.save_images_replace_action != "Replace": + n = 0 + while os.path.exists(filename): + n += 1 + filename = f"{filename_without_extension}-{n}{extension}" + os.replace(temp_file_path, filename) fullfn_without_extension, extension = os.path.splitext(params.filename) if hasattr(os, 'statvfs'): diff --git a/modules/shared_options.py b/modules/shared_options.py index 2f4caa9dd53..1befb6eaeed 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -26,7 +26,7 @@ "samples_format": OptionInfo('png', 'File format for images'), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), - "save_images_add_number_suffix": OptionInfo(True, "Add number suffix when necessary", component_args=hide_dirs).info("prevent existing image from being override"), + "save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}), "grid_save": OptionInfo(True, "Always save all generated image grids"), "grid_format": OptionInfo('png', 'File format for grids'), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), From 4c4d7dd01f77f021381a09cb18b4ca8a8b7734b1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 9 Sep 2023 09:15:09 +0300 Subject: [PATCH 0086/1174] fix whitespace for #13084 --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 
65b63f2f81f..be3e4648486 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -468,7 +468,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(id_task, hypernetwork_name:str, learn_rate:float, batch_size:int, gradient_step:int, data_root:str, log_directory:str, training_width:int, training_height:int, varsize:bool, steps:int, clip_grad_mode:str, clip_grad_value:float, shuffle_tags:bool, tag_drop_out:bool, latent_sampling_method:str, use_weight:bool, create_image_every:int, save_hypernetwork_every:int, template_filename:str, preview_from_txt2img:bool, preview_prompt:str, preview_negative_prompt:str, preview_steps:int, preview_sampler_name:str, preview_cfg_scale:float, preview_seed:int, preview_width:int, preview_height:int): +def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int): from modules import images, processing save_hypernetwork_every = save_hypernetwork_every or 0 From 46375f059276cb2d4d1e47bf65f984c6466dc2a0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 9 Sep 2023 09:39:37 +0300 Subject: [PATCH 0087/1174] fix for crash when running #12924 without --device-id --- modules/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/devices.py b/modules/devices.py index 63c38eff1f1..1d4eb563517 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -60,7 +60,7 @@ def enable_tf32(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() + device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): torch.backends.cudnn.benchmark = True From 46ef1857098df7610c36e73903731e486feca927 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 9 Sep 2023 15:53:10 +0900 Subject: [PATCH 0088/1174] deprecate --enable-console-prompts use --enable-console-prompts as the default value for shared.opts.enable_console_prompts --- modules/cmd_args.py | 2 +- modules/img2img.py | 2 +- modules/shared_options.py | 2 +- modules/txt2img.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index aab62286e24..fe4d4eccc5a 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -90,7 +90,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for 
seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) -parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") diff --git a/modules/img2img.py b/modules/img2img.py index cbd80baccf7..72ee7bc22da 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -199,7 +199,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s p.user = request.username - if shared.opts.enable_console_prompts or shared.cmd_opts.enable_console_prompts: + if shared.opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) if mask: diff --git a/modules/shared_options.py b/modules/shared_options.py index 44fb1670544..1ea5c8f8e7a 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -100,7 +100,7 @@ options_templates.update(options_section(('system', "System"), { "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}), - "enable_console_prompts": OptionInfo(False, "Print prompts to console when generating with txt2img and img2img."), + "enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."), "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(), "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), diff --git a/modules/txt2img.py b/modules/txt2img.py index 379ef85967b..721206dd9e1 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -45,7 +45,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step p.user = request.username - if shared.opts.enable_console_prompts or cmd_opts.enable_console_prompts: + if shared.opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) with closing(p): From c68aabc852151633016d3d5c84b433041f09d96e Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 9 Sep 2023 15:59:22 +0900 Subject: [PATCH 0089/1174] lint --- modules/txt2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/txt2img.py b/modules/txt2img.py index 721206dd9e1..e4e18ceb6dd 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -3,7 +3,7 @@ import modules.scripts from modules import processing from modules.generation_parameters_copypaste import create_override_settings_dict -from 
modules.shared import opts, cmd_opts +from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html import gradio as gr From 9cebe308e9f30f2a9555cf9fc43bf20652c4a619 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 9 Sep 2023 10:20:06 +0300 Subject: [PATCH 0090/1174] return apply styles to main UI --- modules/ui.py | 2 ++ modules/ui_prompt_styles.py | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index b2aed7db3b2..06aa509b150 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -217,6 +217,7 @@ def __init__(self, is_img2img): with gr.Row(elem_id=f"{id_part}_tools"): self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.") self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt", tooltip="Clear prompt") + self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{id_part}_style_apply", tooltip="Apply all selected styles to prompts.") self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False, tooltip="Restore progress") self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) @@ -232,6 +233,7 @@ def __init__(self, is_img2img): ) self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt) + self.ui_styles.setup_apply_button(self.apply_styles) self.prompt_img.change( fn=modules.images.image_data, diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py index 64d379ef65e..3bcf092fddc 100644 --- a/modules/ui_prompt_styles.py +++ b/modules/ui_prompt_styles.py @@ -53,6 +53,8 @@ def refresh_styles(): class UiPromptStyles: def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): self.tabname = tabname + self.main_ui_prompt = main_ui_prompt + self.main_ui_negative_prompt = main_ui_negative_prompt with gr.Row(elem_id=f"{tabname}_styles_row"): self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles") @@ -62,7 +64,7 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): with gr.Row(): self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. 
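That substitution rule is simple enough to state in a few lines. A sketch mirroring the documented behavior, with a hypothetical function name: if the style text contains the {prompt} token, the user's prompt is spliced in at that spot; otherwise the style text is appended after the prompt.

    def apply_style(prompt: str, style_text: str) -> str:
        # splice the user prompt into the style where {prompt} appears
        if "{prompt}" in style_text:
            return style_text.replace("{prompt}", prompt)
        # no token: append the style text to the end of the prompt
        return f"{prompt}, {style_text}" if prompt else style_text

    print(apply_style("a castle", "oil painting of {prompt}, canvas"))  # oil painting of a castle, canvas
    print(apply_style("a castle", "high detail"))                       # a castle, high detail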
Otherwise, style's text will be added to the end of the prompt.") ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles") - self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.") + self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.") self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.") with gr.Row(): @@ -98,12 +100,7 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): show_progress=False, ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False) - self.materialize.click( - fn=materialize_styles, - inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], - outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], - show_progress=False, - ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False) + self.setup_apply_button(self.materialize) self.copy.click( fn=lambda p, n: (p, n), @@ -114,6 +111,10 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close) - - - + def setup_apply_button(self, button): + button.click( + fn=materialize_styles, + inputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown], + outputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown], + show_progress=False, + ).then(fn=None, _js="function(){update_"+self.tabname+"_tokens(); closePopup();}", show_progress=False) From 06af73bd1d17b417f234746c9a2f8e8e92cf6149 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 9 Sep 2023 10:23:53 +0300 Subject: [PATCH 0091/1174] linter --- modules/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/scripts.py b/modules/scripts.py index f1f17a5f739..5c6e0226e2e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -498,7 +498,7 @@ def create_script_ui(self, script): choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string if choices is not None: - setattr(arg_info, 'choices', [x[0] if isinstance(x, tuple) else x for x in choices]) + arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices] api_args.append(arg_info) From c9c457eda8dec414dc38e874691d1e2736d6dcbb Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 9 Sep 2023 10:27:16 +0300 Subject: [PATCH 0092/1174] stylistic changes for #13118 --- modules/ui.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 5f0f1cd16c3..569dc807ced 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,5 +1,4 @@ import datetime -import functools import mimetypes import os import sys @@ -152,7 +151,7 @@ def connect_clear_prompt(button): ) -def update_token_counter(text, steps, is_positive=True): +def update_token_counter(text, steps, *, is_positive=True): try: text, _ = extra_networks.parse_prompt(text) @@ -160,6 +159,7 @@ def 
update_token_counter(text, steps, is_positive=True): _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) else: prompt_flat_list = [text] + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) except Exception: @@ -173,6 +173,10 @@ def update_token_counter(text, steps, is_positive=True): return f"{token_count}/{max_length}" +def update_negative_prompt_token_counter(text, steps): + return update_token_counter(text, steps, is_positive=False) + + class Toprow: """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation""" @@ -539,7 +543,7 @@ def create_ui(): ] toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) - toprow.negative_token_button.click(fn=wrap_queued_call(functools.partial(update_token_counter, is_positive=False)), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) + toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) From 25de9a785cc9e93c16626db6ab5b16824443de53 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 9 Sep 2023 16:56:19 +0900 Subject: [PATCH 0093/1174] Revert "thread safe extra network list_items" This reverts commit aab385d01b4311726127397552d791f4d71b7147. --- extensions-builtin/Lora/ui_extra_networks_lora.py | 10 +++++----- modules/ui_extra_networks.py | 2 -- modules/ui_extra_networks_checkpoints.py | 6 +++--- modules/ui_extra_networks_hypernets.py | 5 ++--- modules/ui_extra_networks_textual_inversion.py | 5 ++--- 5 files changed, 12 insertions(+), 16 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index e9f300621ac..55409a7829d 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -66,11 +66,11 @@ def create_item(self, name, index=None, enable_filter=True): return item def list_items(self): - with self.thread_lock: - for index, name in enumerate(networks.available_networks): - item = self.create_item(name, index) - if item is not None: - yield item + for index, name in enumerate(networks.available_networks): + item = self.create_item(name, index) + + if item is not None: + yield item def allowed_directories_for_previews(self): return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat] diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 564bab7fe02..063bd7b80e6 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,7 +1,6 @@ import os.path import urllib.parse from pathlib import Path -from threading import Lock from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo @@ -95,7 +94,6 @@ def __init__(self, title): self.allow_negative_prompt = False self.metadata = {} self.items = {} - self.thread_lock = Lock() def refresh(self): pass diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 2753214fa10..ca6c26076f9 100644 --- 
a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -30,9 +30,9 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - with self.thread_lock: - for index, name in enumerate(sd_models.checkpoints_list): - yield self.create_item(name, index) + names = list(sd_models.checkpoints_list) + for index, name in enumerate(names): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 411b4f11145..4cedf085196 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -31,9 +31,8 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - with self.thread_lock: - for index, name in enumerate(shared.hypernetworks): - yield self.create_item(name, index) + for index, name in enumerate(shared.hypernetworks): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [shared.cmd_opts.hypernetwork_dir] diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index d25b45d6133..55ef0ea7b54 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -29,9 +29,8 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - with self.thread_lock: - for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): - yield self.create_item(name, index) + for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) From f5959c1c3022c454de22fab749d0f06ab3219868 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 9 Sep 2023 17:05:50 +0900 Subject: [PATCH 0094/1174] thread safe extra network using list --- extensions-builtin/Lora/ui_extra_networks_lora.py | 3 ++- modules/ui_extra_networks_hypernets.py | 3 ++- modules/ui_extra_networks_textual_inversion.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 55409a7829d..e74daa7705e 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -66,7 +66,8 @@ def create_item(self, name, index=None, enable_filter=True): return item def list_items(self): - for index, name in enumerate(networks.available_networks): + names = list(networks.available_networks) + for index, name in enumerate(names): item = self.create_item(name, index) if item is not None: diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 4cedf085196..5f590491594 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -31,7 +31,8 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): - for index, name in enumerate(shared.hypernetworks): + names = list(shared.hypernetworks) + for index, name in enumerate(names): yield self.create_item(name, index) def allowed_directories_for_previews(self): diff --git a/modules/ui_extra_networks_textual_inversion.py 
b/modules/ui_extra_networks_textual_inversion.py
index 55ef0ea7b54..40ab0aca3cf 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -29,7 +29,8 @@ def create_item(self, name, index=None, enable_filter=True):
         }
 
     def list_items(self):
-        for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
+        names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
+        for index, name in enumerate(names):
             yield self.create_item(name, index)
 
     def allowed_directories_for_previews(self):

From f8042cb323f0b581e3a49880dd0023e39d7dcc2c Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 9 Sep 2023 22:35:07 +0900
Subject: [PATCH 0095/1174] Ensure images are not overwritten when a script is enabled

---
 modules/img2img.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/modules/img2img.py b/modules/img2img.py
index 7ca10cf0795..52cb577a646 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -117,6 +117,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
         if output_dir:
             p.outpath_samples = output_dir
             p.override_settings['save_to_dirs'] = False
+            p.override_settings['save_images_replace_action'] = "Add number suffix"
             if p.n_iter > 1 or p.batch_size > 1:
                 p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
             else:
@@ -125,6 +126,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
 
         proc = modules.scripts.scripts_img2img.run(p, *args)
         if proc is None:
+            p.override_settings.pop('save_images_replace_action', None)
             process_images(p)
 

From ab5741717546758c57cf6c2a040645ec2b44690a Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 9 Sep 2023 22:35:50 +0900
Subject: [PATCH 0096/1174] prevent accessing non-existing keys

---
 modules/processing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/processing.py b/modules/processing.py
index 0c19142868e..618f8abee04 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -711,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     if p.scripts is not None:
         p.scripts.before_process(p)
 
-    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
+    stored_opts = {k: opts.data[k] for k in p.override_settings.keys() if k in opts.data}
 
     try:
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint

From d6478a60aa7c6f96a959ca6e3b9e8d51ad22d895 Mon Sep 17 00:00:00 2001
From: zixaphir
Date: Sat, 9 Sep 2023 17:22:10 -0700
Subject: [PATCH 0097/1174] Remove extra network separator without regex

---
 javascript/extraNetworks.js | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index ca87beade13..ff58d3dc8b1 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -140,23 +140,19 @@ function setupExtraNetworks() {
 
 onUiLoaded(setupExtraNetworks);
 
-var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/;
-var re_extranet_str = '<([^:]+:[^:]+):[\\d.]+>';
+var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
+var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
 
 function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
-    function reEscape(s) {
-        return s.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
-    }
     var m = text.match(re_extranet);
     var replaced = false;
    var newTextareaText;
     if (m) {
+        var extraTextBeforeNet = 
opts.extra_networks_add_text_separator; var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; - var escapedSeparator = `(?:${reEscape(opts.extra_networks_add_text_separator)})?`; - var re = new RegExp(escapedSeparator + re_extranet_str, 'g'); - newTextareaText = textarea.value.replaceAll(re, function(found, net, pos) { + newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) { m = found.match(re_extranet); if (m[1] == partToSearch) { replaced = true; @@ -166,8 +162,13 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { return found; }); - if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { - newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); + if (foundAtPosition >= 0) { + if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); + } + if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) { + newTextareaText = newTextareaText.substr(0, foundAtPosition - extraTextBeforeNet.length) + newTextareaText.substr(foundAtPosition); + } } } else { newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) { From 26d0d87f5b05c345abfec8eb0f8bd703cacd9619 Mon Sep 17 00:00:00 2001 From: zixaphir Date: Sat, 9 Sep 2023 17:26:46 -0700 Subject: [PATCH 0098/1174] Remove extra spaces --- javascript/extraNetworks.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index ff58d3dc8b1..158b5b6478e 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -140,7 +140,7 @@ function setupExtraNetworks() { onUiLoaded(setupExtraNetworks); -var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/; +var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/; var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g; function tryToRemoveExtraNetworkFromPrompt(textarea, text) { From 7d4d871d4679b5b78ff67b501da5367413542984 Mon Sep 17 00:00:00 2001 From: dongwenpu Date: Sun, 10 Sep 2023 17:53:42 +0800 Subject: [PATCH 0099/1174] fix: lora-bias-backup don't reset cache --- extensions-builtin/Lora/networks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 96f935b236f..315682b31b0 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -418,6 +418,7 @@ def network_forward(module, input, original_forward): def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): self.network_current_names = () self.network_weights_backup = None + self.network_bias_backup = None def network_Linear_forward(self, input): From 413123f08a745e9417fd384d2c1bee1e0e5e5730 Mon Sep 17 00:00:00 2001 From: liubo0902 <38622806+liubo0902@users.noreply.github.com> Date: Mon, 11 Sep 2023 09:22:27 +0800 Subject: [PATCH 0100/1174] Update localization.py --- modules/localization.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modules/localization.py b/modules/localization.py index 262d49ee1c8..108f792e96f 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -14,12 +14,10 @@ def list_localizations(dirname): if ext.lower() != ".json": continue - fn = fn.replace(" ", 
"").replace("(", "_").replace(")","") localizations[fn] = [os.path.join(dirname, file)] for file in scripts.list_scripts("localizations", ".json"): fn, ext = os.path.splitext(file.filename) - fn = fn.replace(" ", "").replace("(", "_").replace(")","") if fn not in localizations: localizations[fn] = [] localizations[fn].append(file.path) From c485a7d12e26df1497309023d09cb3c106d7ae2b Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 11 Sep 2023 13:47:44 +0900 Subject: [PATCH 0101/1174] make InputAccordion work with ui-config --- modules/ui_loadsave.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index ec8fa8e89e3..eb20ff258a4 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -4,7 +4,7 @@ import gradio as gr from modules import errors -from modules.ui_components import ToolButton +from modules.ui_components import ToolButton, InputAccordion def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs @@ -32,8 +32,6 @@ def __init__(self, filename): self.error_loading = True errors.display(e, "loading settings") - - def add_component(self, path, x): """adds component to the registry of tracked components""" @@ -43,20 +41,24 @@ def apply_field(obj, field, condition=None, init_field=None): key = f"{path}/{field}" if getattr(obj, 'custom_script_source', None) is not None: - key = f"customscript/{obj.custom_script_source}/{key}" + key = f"customscript/{obj.custom_script_source}/{key}" if getattr(obj, 'do_not_save_to_config', False): return saved_value = self.ui_settings.get(key, None) + + if isinstance(obj, gr.Accordion) and isinstance(x, InputAccordion) and field == 'value': + field = 'open' + if saved_value is None: self.ui_settings[key] = getattr(obj, field) elif condition and not condition(saved_value): pass else: - if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies + if isinstance(obj, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies saved_value = str(saved_value) - elif isinstance(x, gr.Number) and field == 'value': + elif isinstance(obj, gr.Number) and field == 'value': try: saved_value = float(saved_value) except ValueError: @@ -67,7 +69,7 @@ def apply_field(obj, field, condition=None, init_field=None): init_field(saved_value) if field == 'value' and key not in self.component_mapping: - self.component_mapping[key] = x + self.component_mapping[key] = obj if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible: apply_field(x, 'visible') @@ -100,6 +102,12 @@ def check_dropdown(val): apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) + if type(x) == InputAccordion: + if x.accordion.visible: + apply_field(x.accordion, 'visible') + apply_field(x, 'value') + apply_field(x.accordion, 'value') + def check_tab_id(tab_id): tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children)) if type(tab_id) == str: From e785402b6acca12108e15224ff80d58817ab3c27 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 9 Sep 2023 17:28:06 +0900 Subject: [PATCH 0102/1174] return nothing if not found --- extensions-builtin/Lora/ui_extra_networks_lora.py | 3 ++- modules/ui_extra_networks_checkpoints.py | 7 
++++++- modules/ui_extra_networks_hypernets.py | 9 +++++++-- modules/ui_extra_networks_textual_inversion.py | 6 +++++- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index e74daa7705e..dac90a86c02 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -17,6 +17,8 @@ def refresh(self): def create_item(self, name, index=None, enable_filter=True): lora_on_disk = networks.available_networks.get(name) + if lora_on_disk is None: + return path, ext = os.path.splitext(lora_on_disk.filename) @@ -69,7 +71,6 @@ def list_items(self): names = list(networks.available_networks) for index, name in enumerate(names): item = self.create_item(name, index) - if item is not None: yield item diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index ca6c26076f9..35e958a0090 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -15,6 +15,9 @@ def refresh(self): def create_item(self, name, index=None, enable_filter=True): checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) + if checkpoint is None: + return + path, ext = os.path.splitext(checkpoint.filename) return { "name": checkpoint.name_for_extra, @@ -32,7 +35,9 @@ def create_item(self, name, index=None, enable_filter=True): def list_items(self): names = list(sd_models.checkpoints_list) for index, name in enumerate(names): - yield self.create_item(name, index) + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 5f590491594..74f7d8472fc 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -13,7 +13,10 @@ def refresh(self): shared.reload_hypernetworks() def create_item(self, name, index=None, enable_filter=True): - full_path = shared.hypernetworks[name] + full_path = shared.hypernetworks.get(name) + if full_path is None: + return + path, ext = os.path.splitext(full_path) sha256 = sha256_from_cache(full_path, f'hypernet/{name}') shorthash = sha256[0:10] if sha256 else None @@ -33,7 +36,9 @@ def create_item(self, name, index=None, enable_filter=True): def list_items(self): names = list(shared.hypernetworks) for index, name in enumerate(names): - yield self.create_item(name, index) + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return [shared.cmd_opts.hypernetwork_dir] diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 40ab0aca3cf..71c38fabc1a 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -14,6 +14,8 @@ def refresh(self): def create_item(self, name, index=None, enable_filter=True): embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) + if embedding is None: + return path, ext = os.path.splitext(embedding.filename) return { @@ -31,7 +33,9 @@ def create_item(self, name, index=None, enable_filter=True): def list_items(self): names = list(sd_hijack.model_hijack.embedding_db.word_embeddings) for index, name in enumerate(names): - yield 
self.create_item(name, index) + item = self.create_item(name, index) + if item is not None: + yield item def allowed_directories_for_previews(self): return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) From 59544321aa019d71d220b1da1eec703aa44fa8eb Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 11 Sep 2023 21:17:28 +0300 Subject: [PATCH 0103/1174] initial work on sd_unet for SDXL --- modules/sd_hijack.py | 17 ++++++++++++----- modules/sd_unet.py | 4 ++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 592f00551f1..22a1eb5ca7d 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -2,7 +2,7 @@ from torch.nn.functional import silu from types import MethodType -from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr @@ -10,6 +10,7 @@ import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.openaimodel +import ldm.models.diffusion.ddpm import ldm.models.diffusion.ddim import ldm.models.diffusion.plms import ldm.modules.encoders.modules @@ -37,6 +38,8 @@ optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None +ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) +sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) def list_optimizers(): new_optimizers = script_callbacks.list_optimizers_callback() @@ -239,10 +242,13 @@ def flatten(el): self.layers = flatten(m) - if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'): - ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward + if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion): + sd_unet.original_forward = ldm_original_forward + elif isinstance(m, sgm.models.diffusion.DiffusionEngine): + sd_unet.original_forward = sgm_original_forward + else: + sd_unet.original_forward = None - ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward def undo_hijack(self, m): conditioner = getattr(m, 'conditioner', None) @@ -279,7 +285,8 @@ def undo_hijack(self, m): self.layers = None self.clip = None - ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui + sd_unet.original_forward = None + def apply_circular(self, enable): if self.circular_enabled == enable: diff --git a/modules/sd_unet.py b/modules/sd_unet.py index 5525cfbc3a0..6a7bc9e2646 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -1,11 +1,11 @@ import torch.nn -import ldm.modules.diffusionmodules.openaimodel from modules import script_callbacks, shared, devices unet_options = [] current_unet_option = None current_unet = None +original_forward = None def list_unets(): @@ -88,5 +88,5 @@ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): if current_unet is not None: return current_unet.forward(x, timesteps, context, *args, **kwargs) - return 
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs) + return original_forward(self, x, timesteps, context, *args, **kwargs) From 74b80e72115af46bf1c04167a30f9ec5025cb464 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 12 Sep 2023 09:29:07 +0900 Subject: [PATCH 0104/1174] add comment --- extensions-builtin/Lora/ui_extra_networks_lora.py | 1 + modules/ui_extra_networks_checkpoints.py | 1 + modules/ui_extra_networks_hypernets.py | 1 + modules/ui_extra_networks_textual_inversion.py | 1 + 4 files changed, 4 insertions(+) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index dac90a86c02..df02c663b12 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -68,6 +68,7 @@ def create_item(self, name, index=None, enable_filter=True): return item def list_items(self): + # instantiate a list to protect against concurrent modification names = list(networks.available_networks) for index, name in enumerate(names): item = self.create_item(name, index) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 35e958a0090..df7efb2e1fb 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -33,6 +33,7 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): + # instantiate a list to protect against concurrent modification names = list(sd_models.checkpoints_list) for index, name in enumerate(names): item = self.create_item(name, index) diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 74f7d8472fc..c96c4fa3b12 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -34,6 +34,7 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): + # instantiate a list to protect against concurrent modification names = list(shared.hypernetworks) for index, name in enumerate(names): item = self.create_item(name, index) diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 71c38fabc1a..1b334fda174 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -31,6 +31,7 @@ def create_item(self, name, index=None, enable_filter=True): } def list_items(self): + # instantiate a list to protect against concurrent modification names = list(sd_hijack.model_hijack.embedding_db.word_embeddings) for index, name in enumerate(names): item = self.create_item(name, index) From 6fb2194d9cc2c9b52bc2006117d592283e00b7d6 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:50:56 +0900 Subject: [PATCH 0105/1174] fetch version info when webui_dir is not work_dir --- modules/launch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 6e54d06367c..8cdbafa50c8 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -64,7 +64,7 @@ def check_python_version(): @lru_cache() def commit_hash(): try: - return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip() + return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip() 
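# A note on the "-C" flag added in the line above: `git -C <path>` makes git
# behave as if it had been launched from <path>, so the webui checkout at
# script_path is queried even when the server's working directory is somewhere
# else. A minimal stand-alone sketch of the same pattern (hypothetical helper
# name, standard library only; not taken from the diffs in this series):
import subprocess

def repo_commit(repo_dir: str) -> str:
    # Resolve HEAD of the repository at repo_dir, independent of os.getcwd().
    try:
        return subprocess.check_output(["git", "-C", repo_dir, "rev-parse", "HEAD"], shell=False, encoding="utf8").strip()
    except Exception:
        return ""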
     except Exception:
         return ""
 
 
@@ -72,7 +72,7 @@ def git_tag():
     try:
-        return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
+        return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
     except Exception:
         try:

From 93015964c7c920fbb834bf99977ab8e16296efac Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Tue, 12 Sep 2023 22:43:35 +0900
Subject: [PATCH 0106/1174] fix add_option overriding config with default

---
 modules/options.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/modules/options.py b/modules/options.py
index e75916d2be2..ab40aff73f0 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -210,7 +210,8 @@ def dumpjson(self):
 
     def add_option(self, key, info):
         self.data_labels[key] = info
-        self.data[key] = info.default
+        if key not in self.data:
+            self.data[key] = info.default
 
     def reorder(self):
         """reorder settings so that all items related to section always go together"""

From 5b761b49ade392eb8ff4c54ab1841b63475b5dd0 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 13 Sep 2023 16:05:55 +0900
Subject: [PATCH 0107/1174] correct webpath when webui_dir is not work_dir

---
 modules/paths.py                | 2 +-
 modules/paths_internal.py       | 1 +
 modules/ui_gradio_extensions.py | 6 +++---
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/modules/paths.py b/modules/paths.py
index 2505233999b..187b949612a 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,6 +1,6 @@
 import os
 import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir  # noqa: F401
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd  # noqa: F401
 
 import modules.safe  # noqa: F401
 
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 005a9b0aa75..89131a54fa1 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -8,6 +8,7 @@
 commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
 sys.argv += shlex.split(commandline_args)
 
+cwd = os.getcwd()
 modules_path = os.path.dirname(os.path.realpath(__file__))
 script_path = os.path.dirname(modules_path)
 
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
index b824b113732..0d368f8b2c4 100644
--- a/modules/ui_gradio_extensions.py
+++ b/modules/ui_gradio_extensions.py
@@ -2,12 +2,12 @@ import gradio as gr
 
 from modules import localization, shared, scripts
-from modules.paths import script_path, data_path
+from modules.paths import script_path, data_path, cwd
 
 
 def webpath(fn):
-    if fn.startswith(script_path):
-        web_path = os.path.relpath(fn, script_path).replace('\\', '/')
+    if fn.startswith(cwd):
+        web_path = os.path.relpath(fn, cwd)
     else:
         web_path = os.path.abspath(fn)
 
From cf1edc2b54d7ae8acec45ddba098957a6caa7867 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 13 Sep 2023 16:27:02 +0900
Subject: [PATCH 0108/1174] initialize state.time_start before state.job_count

---
 modules/shared_state.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/shared_state.py b/modules/shared_state.py
index d272ee5bc2c..a68789cc815 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -103,6 +103,7 @@ def dict(self):
 
     def begin(self, job: str = "(unknown)"):
         self.sampling_step = 0
+        self.time_start = 
time.time() self.job_count = -1 self.processing_has_refined_job_count = False self.job_no = 0 @@ -114,7 +115,6 @@ def begin(self, job: str = "(unknown)"): self.skipped = False self.interrupted = False self.textinfo = None - self.time_start = time.time() self.job = job devices.torch_gc() log.info("Starting job %s", job) From 0ad38a9b87b7781315ea6324a8aa6c924d1275de Mon Sep 17 00:00:00 2001 From: Der Chien Date: Wed, 13 Sep 2023 20:20:01 +0800 Subject: [PATCH 0109/1174] 20230913 setup GIT_PYTHON_GIT_EXECUTABLE for GitPython --- webui.bat | 1 + webui.sh | 2 ++ 2 files changed, 3 insertions(+) diff --git a/webui.bat b/webui.bat index 42e7d517d18..a630ea4d9ff 100644 --- a/webui.bat +++ b/webui.bat @@ -1,6 +1,7 @@ @echo off if not defined PYTHON (set PYTHON=python) +if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%") if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") set SD_WEBUI_RESTART=tmp/restart diff --git a/webui.sh b/webui.sh index 3d0f87eed74..bdab3f0538d 100755 --- a/webui.sh +++ b/webui.sh @@ -51,6 +51,8 @@ fi if [[ -z "${GIT}" ]] then export GIT="git" +else + export GIT_PYTHON_GIT_EXECUTABLE="${GIT}" fi # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) From ab3d3528a18ea1a81f1af22ea71bfc0d8c710dde Mon Sep 17 00:00:00 2001 From: Leon Date: Thu, 14 Sep 2023 18:42:56 +0800 Subject: [PATCH 0110/1174] add --skip-load-model-at-start --- modules/cmd_args.py | 1 + modules/initialize.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 5be879dd6f3..4e602a84270 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -117,3 +117,4 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) +parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) diff --git a/modules/initialize.py b/modules/initialize.py index f24f76375db..ac95fc6f00c 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -151,8 +151,8 @@ def load_model(): from modules import devices devices.first_time_calculation() - - Thread(target=load_model).start() + if not shared.cmd_opts.skip_load_model_at_start: + Thread(target=load_model).start() from modules import shared_items shared_items.reload_hypernetworks() From afd06245876004710007fa1abd0a1b4b2564c181 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 15 Sep 2023 17:10:01 +0900 Subject: [PATCH 0111/1174] xyz_grid: add prepare option to AxisOption --- scripts/xyz_grid.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 939d86053bd..ce5a1a19d07 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -205,13 +205,14 @@ def csv_string_to_list_strip(data_str): class AxisOption: - def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): + def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None): self.label = label self.type = type self.apply = apply self.format_value = format_value 
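# The `prepare` hook added in this patch lets an axis option turn the raw text
# field into a value list itself, instead of the default CSV splitting. A
# sketch of a hypothetical axis using it (assumes the AxisOption signature
# above and xyz_grid's existing apply_field helper; not part of this patch):
example_range_axis = AxisOption(
    "Steps (start-stop)", int, apply_field("steps"),
    prepare=lambda text: list(range(*map(int, text.split("-")))),  # "10-30" -> [10..29]
)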
         self.confirm = confirm
         self.cost = cost
+        self.prepare = prepare
         self.choices = choices
 
 
@@ -536,6 +537,8 @@ def process_axis(opt, vals, vals_dropdown):
 
         if opt.choices is not None and not csv_mode:
             valslist = vals_dropdown
+        elif opt.prepare is not None:
+            valslist = opt.prepare(vals)
         else:
             valslist = csv_string_to_list_strip(vals)
 

From 813535d38bbcdd8ccc51d0618a7d9fd353677bb9 Mon Sep 17 00:00:00 2001
From: "qiuwen.wang"
Date: Fri, 15 Sep 2023 18:23:23 +0800
Subject: [PATCH 0112/1174] using dict[key] = model does not update the OrderedDict order; use move_to_end instead

---
 modules/sd_models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee5c8..6d17dd3c7b6 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -309,6 +309,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
 
     if checkpoint_info in checkpoints_loaded:
         # use checkpoint cache
         print(f"Loading weights [{sd_model_hash}] from cache")
+        checkpoints_loaded.move_to_end(checkpoint_info)
         return checkpoints_loaded[checkpoint_info]
 
     print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")

From d9d94141dcfc1a84e98370bc137ffd888509b65e Mon Sep 17 00:00:00 2001
From: woweenie <145132974+woweenie@users.noreply.github.com>
Date: Fri, 15 Sep 2023 18:59:44 +0200
Subject: [PATCH 0113/1174] patch DDPM.register_betas so that users can put given_betas in model yaml

---
 modules/sd_models.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee5c8..8e4983a4629 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,7 +7,7 @@ import torch
 import re
 import safetensors.torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, ListConfig
 from os import mkdir
 from urllib import request
 import ldm.modules.midas as midas
@@ -17,6 +17,7 @@
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
 from modules.timer import Timer
 import tomesd
+import numpy as np
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -132,6 +133,7 @@ def setup_model():
     os.makedirs(model_path, exist_ok=True)
 
     enable_midas_autodownload()
+    patch_given_betas()
 
 
 def checkpoint_tiles(use_short=False):
@@ -453,6 +455,17 @@ def load_model_wrapper(model_type):
     midas.api.load_model = load_model_wrapper
 
 
+def patch_given_betas():
+    original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
+    def patched_register_schedule(*args, **kwargs):
+        if args[1] is not None and isinstance(args[1], ListConfig):
+            modified_args = list(args)  # Convert args tuple to a list
+            modified_args[1] = np.array(args[1])  # Modify the desired element
+            args = tuple(modified_args)  # Convert the list back to a tuple
+        original_register_schedule(*args, **kwargs)
+    ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule
+
+
 def repair_config(sd_config):
 
     if not hasattr(sd_config.model.params, "use_ema"):

From 8e355fbd7552f1a7f5124c4685d6fa36f3d0ede1 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 16 Sep 2023 09:05:42 +0900
Subject: [PATCH 0114/1174] Config states time ISO in system time zone

---
 modules/config_states.py | 3 +--
 modules/ui_extensions.py | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git 
a/modules/config_states.py b/modules/config_states.py
index b766aef11d8..651793c7f6f 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -4,7 +4,6 @@
 import os
 import json
-import time
 import tqdm
 from datetime import datetime
 
@@ -38,7 +37,7 @@ def list_config_states():
     config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
 
     for cs in config_states:
-        timestamp = time.asctime(time.gmtime(cs["created_at"]))
+        timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
         name = cs.get("name", "Config")
         full_name = f"{name}: {timestamp}"
         all_config_states[full_name] = cs
 
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 2e8c1d6d21d..c0a73b57380 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -197,7 +197,7 @@ def update_config_states_table(state_name):
         config_state = config_states.all_config_states[state_name]
 
     config_name = config_state.get("name", "Config")
-    created_date = time.asctime(time.gmtime(config_state["created_at"]))
+    created_date = datetime.fromtimestamp(config_state["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
     filepath = config_state.get("filepath", "")
 
     try:

From d2878a8b0b952c08d832e0308e575e352a1bc3f1 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 16 Sep 2023 09:49:53 +0900
Subject: [PATCH 0115/1174] XYZ: do not save sub grids when Include Sub Grids is off

---
 scripts/xyz_grid.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 939d86053bd..99ad96be264 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -773,6 +773,8 @@ def cell(x, y, z, ix, iy, iz):
             # TODO: See previous comment about intentional data misalignment.
             adj_g = g-1 if g > 0 else g
             images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
+            if not include_sub_grids:  # if not include_sub_grids then skip saving after the first grid
+                break
 
     if not include_sub_grids:
         # Done with sub-grids, drop all related information:

From 701feabf496b7ce0327ccdb1ef1dc942deab25ea Mon Sep 17 00:00:00 2001
From: Zolxys
Date: Sun, 17 Sep 2023 11:37:15 -0500
Subject: [PATCH 0116/1174] Fix: --sd_model in "Prompts from file or textbox" script is not working

Fix for bug report #8079

---
 scripts/prompts_from_file.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index 50320d553bd..ca73b2a56bf 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -5,11 +5,17 @@ import modules.scripts as scripts
 import gradio as gr
 
-from modules import sd_samplers, errors
+from modules import sd_samplers, errors, sd_models
 from modules.processing import Processed, process_images
 from modules.shared import state
 
 
+def process_model_tag(tag):
+    info = sd_models.get_closet_checkpoint_match(tag)
+    assert info is not None, f'Unknown checkpoint: {tag}'
+    return info.name
+
+
 def process_string_tag(tag):
     return tag
 
@@ -27,7 +33,7 @@ def process_boolean_tag(tag):
 
 prompt_tags = {
-    "sd_model": None,
+    "sd_model": process_model_tag,
     "outpath_samples": process_string_tag,
     "outpath_grids": process_string_tag,
     "prompt_for_display": process_string_tag,
@@ -156,7 +162,10 @@ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
             copy_p = copy.copy(p)
             for k, v in 
args.items(): - setattr(copy_p, k, v) + if k == "sd_model": + copy_p.override_settings['sd_model_checkpoint'] = v + else: + setattr(copy_p, k, v) proc = process_images(copy_p) images += proc.images From 8e355fbd7552f1a7f5124c4685d6fa36f3d0ede1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A7=8B=E6=96=87/qwwang?= Date: Mon, 18 Sep 2023 16:45:42 +0800 Subject: [PATCH 0117/1174] fix --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 6d17dd3c7b6..eedb38c65ad 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -309,6 +309,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): if checkpoint_info in checkpoints_loaded: # use checkpoint cache print(f"Loading weights [{sd_model_hash}] from cache") + # move to end as latest checkpoints_loaded.move_to_end(checkpoint_info) return checkpoints_loaded[checkpoint_info] From 702a1e1cc70240f2adbcfb707a644a5a98b5443c Mon Sep 17 00:00:00 2001 From: superhero-7 <537093830@qq.com> Date: Sat, 23 Sep 2023 17:51:41 +0800 Subject: [PATCH 0118/1174] support m18 --- configs/alt-diffusion-m18-inference.yaml | 73 ++++++++++ modules/sd_hijack.py | 6 +- modules/sd_models_config.py | 6 +- modules/xlmr_m18.py | 164 +++++++++++++++++++++++ 4 files changed, 244 insertions(+), 5 deletions(-) create mode 100644 configs/alt-diffusion-m18-inference.yaml create mode 100644 modules/xlmr_m18.py diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml new file mode 100644 index 00000000000..41a031d55f0 --- /dev/null +++ b/configs/alt-diffusion-m18-inference.yaml @@ -0,0 +1,73 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: modules.xlmr_m18.BertSeriesModelWithTransformation + params: + name: "XLMR-Large" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 592f00551f1..ae9b2a656ac 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -208,11 +208,10 @@ def hijack(self, m): else: m.cond_stage_model = conditioner - if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) - elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) @@ -258,7 +257,6 @@ def undo_hijack(self, m): if hasattr(m, 'cond_stage_model'): delattr(m, 'cond_stage_model') - elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 08dd03f19c7..9ba89dfc0eb 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -21,7 +21,7 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") - +config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") def is_using_v_parameterization_for_sd2(state_dict): """ @@ -95,7 +95,11 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8: return config_instruct_pix2pix + + # import pdb; pdb.set_trace() if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: + if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: + return config_alt_diffusion_m18 return 
config_alt_diffusion return config_default diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py new file mode 100644 index 00000000000..18785692a65 --- /dev/null +++ b/modules/xlmr_m18.py @@ -0,0 +1,164 @@ +from transformers import BertPreTrainedModel,BertModel,BertConfig +import torch.nn as nn +import torch +from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +from transformers import XLMRobertaModel,XLMRobertaTokenizer +from typing import Optional + +class BertSeriesConfig(BertConfig): + def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): + + super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + + +class BertSeriesModelWithTransformation(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + config_class = BertSeriesConfig + + def __init__(self, config=None, **kargs): + # modify initialization for autoloading + if config is None: + config = XLMRobertaConfig() + config.attention_probs_dropout_prob= 0.1 + config.bos_token_id=0 + config.eos_token_id=2 + config.hidden_act='gelu' + config.hidden_dropout_prob=0.1 + config.hidden_size=1024 + config.initializer_range=0.02 + config.intermediate_size=4096 + config.layer_norm_eps=1e-05 + config.max_position_embeddings=514 + + config.num_attention_heads=16 + config.num_hidden_layers=24 + config.output_past=True + config.pad_token_id=1 + config.position_embedding_type= "absolute" + + config.type_vocab_size= 1 + config.use_cache=True + config.vocab_size= 250002 + config.project_dim = 1024 + config.learn_encoder = False + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size,config.project_dim) + # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + # self.pooler = lambda x: x[:,0] + # self.post_init() + + self.has_pre_transformation = True + if self.has_pre_transformation: + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def encode(self,c): + device = next(self.parameters()).device + text = self.tokenizer(c, + truncation=True, + max_length=77, + return_length=False, + 
return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + text["input_ids"] = torch.tensor(text["input_ids"]).to(device) + text["attention_mask"] = torch.tensor( + text['attention_mask']).to(device) + features = self(**text) + return features['projection_state'] + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) : + r""" + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + # # last module outputs + # sequence_output = outputs[0] + + + # # project every module + # sequence_output_ln = self.pre_LN(sequence_output) + + # # pooler + # pooler_output = self.pooler(sequence_output_ln) + # pooler_output = self.transformation(pooler_output) + # projection_state = self.transformation(outputs.last_hidden_state) + + if self.has_pre_transformation: + sequence_output2 = outputs["hidden_states"][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) + + return { + "projection_state": projection_state2, + "last_hidden_state": outputs.last_hidden_state, + "hidden_states": outputs.hidden_states, + "attentions": outputs.attentions, + } + else: + projection_state = self.transformation(outputs.last_hidden_state) + return { + "projection_state": projection_state, + "last_hidden_state": outputs.last_hidden_state, + "hidden_states": outputs.hidden_states, + "attentions": outputs.attentions, + } + + + # return { + # 'pooler_output':pooler_output, + # 'last_hidden_state':outputs.last_hidden_state, + # 'hidden_states':outputs.hidden_states, + # 'attentions':outputs.attentions, + # 'projection_state':projection_state, + # 'sequence_out': sequence_output + # } + + +class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): + base_model_prefix = 'roberta' + config_class= RobertaSeriesConfig \ No newline at end of file From f8f4ff2bb8f56877dede466934dd8ddf25c21063 Mon Sep 17 00:00:00 2001 From: superhero-7 <537093830@qq.com> Date: Sat, 23 Sep 2023 17:55:19 +0800 Subject: [PATCH 0119/1174] support altdiffusion-m18 --- modules/sd_hijack.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ae9b2a656ac..4b36c0e9c13 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -212,6 +212,7 @@ def hijack(self, m): model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: 
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) @@ -257,6 +258,7 @@ def undo_hijack(self, m): if hasattr(m, 'cond_stage_model'): delattr(m, 'cond_stage_model') + elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped From fdecf813b63db4a4e49b92ebfdf705a0d4047287 Mon Sep 17 00:00:00 2001 From: ezt19 <136929737+ezt19@users.noreply.github.com> Date: Sat, 23 Sep 2023 20:41:28 +0000 Subject: [PATCH 0120/1174] Update dragdrop.js Fixing a problem when u cannot put two images and they are going into two different places for images. --- javascript/dragdrop.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 5803daea5ef..d680daf52f2 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -119,7 +119,7 @@ window.addEventListener('paste', e => { } const firstFreeImageField = visibleImageFields - .filter(el => el.querySelector('input[type=file]'))?.[0]; + .filter(el => !el.querySelector('img'))?.[0]; dropReplaceImage( firstFreeImageField ? From d00f6dca2825c2a73bbc1a5c707be276d62acc6b Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Mon, 25 Sep 2023 22:08:24 -0600 Subject: [PATCH 0121/1174] Escape item names --- modules/ui_extra_networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 063bd7b80e6..60b95f21855 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -213,9 +213,9 @@ def create_html_for_item(self, item, tabname): metadata_button = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" + metadata_button = f"" - edit_button = f"
" + edit_button = f"
" local_path = "" filename = item.get("filename", "") From 99aa702015b2a7c6707081cc975cfd12c40d55c4 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 26 Sep 2023 21:08:55 -0600 Subject: [PATCH 0122/1174] Update card on correct tab --- javascript/extraNetworks.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 493f31af28a..e927346c332 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -335,7 +335,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { function extraNetworksRefreshSingleCard(page, tabname, name) { requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) { if (data && data.html) { - var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function + var card = gradioApp().querySelector(`#${tabname}_${page.replace(" ", "_")}_cards > .card[data-name="${name}"]`); var newDiv = document.createElement('DIV'); newDiv.innerHTML = data.html; From a69daae012458bbd3d2cc472dc757fd78090ae05 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Tue, 26 Sep 2023 22:02:52 -0600 Subject: [PATCH 0123/1174] Fix data-sort-name containing spaces --- modules/ui_extra_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 063bd7b80e6..799bd17402c 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -235,7 +235,7 @@ def create_html_for_item(self, item, tabname): if search_only and shared.opts.extra_networks_hidden_models == "Never": return "" - sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip() + sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() args = { "background_image": background_image, From 30f4f25b2ed77464104f31c35b40f5994ffce67b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Sep 2023 10:21:14 +0300 Subject: [PATCH 0124/1174] Bump to torchsde==0.2.6 --- requirements_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index f8ae1f385ae..7d27f2be3bb 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -27,5 +27,5 @@ timm==0.9.2 tomesd==0.1.3 torch torchdiffeq==0.2.3 -torchsde==0.2.5 +torchsde==0.2.6 transformers==4.30.2 From ad3b8a1c41b8ab98b8b98c6364d13cb8b6d5fa88 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Sep 2023 08:23:12 +0300 Subject: [PATCH 0125/1174] alternative solution to #13434 --- modules/restart.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/restart.py b/modules/restart.py index 18eacaf377e..2dd6493b450 100644 --- a/modules/restart.py +++ b/modules/restart.py @@ -14,7 +14,9 @@ def is_restartable() -> bool: def restart_program() -> None: """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again""" - (Path(script_path) / "tmp" / "restart").touch() + tmpdir = Path(script_path) / "tmp" + tmpdir.mkdir(parents=True, exist_ok=True) + (tmpdir / "restart").touch() stop_program() From 87b50397a6da273fe0160016a209e4eb0cbf4a89 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Sep 2023 09:11:31 +0300 Subject: [PATCH 
0126/1174] add missing import, simplify code, use patches module for #13276 --- modules/sd_models.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index e3755253df8..5ef7aa138ed 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches from modules.timer import Timer import tomesd import numpy as np @@ -130,6 +130,8 @@ def calculate_shorthash(self): def setup_model(): + """called once at startup to do various one-time tasks related to SD models""" + os.makedirs(model_path, exist_ok=True) enable_midas_autodownload() @@ -458,14 +460,17 @@ def load_model_wrapper(model_type): def patch_given_betas(): - original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule + import ldm.models.diffusion.ddpm + def patched_register_schedule(*args, **kwargs): - if args[1] is not None and isinstance(args[1], ListConfig): - modified_args = list(args) # Convert args tuple to a list - modified_args[1] = np.array(args[1]) # Modify the desired element - args = tuple(modified_args) # Convert the list back to a tuple + """a modified version of register_schedule function that converts plain list from Omegaconf into numpy""" + + if isinstance(args[1], ListConfig): + args = (args[0], np.array(args[1]), *args[2:]) + original_register_schedule(*args, **kwargs) - ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule + + original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule) def repair_config(sd_config): From ab63054f95a7b7dfb2f0ffa88e805c8d2f950605 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Sep 2023 09:34:50 +0300 Subject: [PATCH 0127/1174] write infotext to gif image as comment --- modules/images.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/images.py b/modules/images.py index 1e63118e311..daf4eebe4b1 100644 --- a/modules/images.py +++ b/modules/images.py @@ -561,6 +561,8 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p }) piexif.insert(exif_bytes, filename) + elif extension.lower() == ".gif": + image.save(filename, format=image_format, comment=geninfo) else: image.save(filename, format=image_format, quality=opts.jpeg_quality) From 1cc7c4bfb31b80b6667154145f1455541951db18 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 30 Sep 2023 01:09:09 -0600 Subject: [PATCH 0128/1174] Allow editing whitespace delimiters --- javascript/edit-attention.js | 3 ++- modules/shared_options.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 8906c8922e1..bc4ebed4ca7 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -46,7 +46,8 @@ function keyupEditAttention(event) { function selectCurrentWord() { if (selectionStart !== selectionEnd) return false; - const delimiters = opts.keyedit_delimiters + " \r\n\t"; + let delimiters = 
opts.keyedit_delimiters.replace(/(^|[^\\])(\\\\)*\\$/, "$&\\").replace(/(^|[^\\])((\\\\)*")/g, "$1\\$2"); + delimiters = JSON.parse(`"${delimiters}"`); // seek backward until to find beggining while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { diff --git a/modules/shared_options.py b/modules/shared_options.py index 00b273faa54..a1f157c6603 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -255,7 +255,7 @@ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), + "keyedit_delimiters": OptionInfo(r".,\\/!?%^*;:{}=`~() \r\n\t", "Ctrl+up/down word delimiters"), "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), From 5cc7bf387661354e5c268c0e8198cc44328b3282 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Sep 2023 10:10:57 +0300 Subject: [PATCH 0129/1174] reword sd_checkpoint_dropdown_use_short setting and add explanation --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 8e8d402d990..61b93d47a9c 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -262,7 +262,7 @@ "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), - "sd_checkpoint_dropdown_use_short": OptionInfo(False, "Use short name for Stable Diffusion checkpoint dropdown"), + "sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"), "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), From b2f9709538ee40c6bbf11e3f17f7e3ea4b9cb78a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Sep 2023 10:29:10 +0300 Subject: [PATCH 0130/1174] get #13121 to work without restart --- modules/ui_extra_networks.py | 23 +++++++++++++++-------- 1 file 
changed, 15 insertions(+), 8 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index d8c31142373..59d6ecc612a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,3 +1,4 @@ +import functools import os.path import urllib.parse from pathlib import Path @@ -15,10 +16,16 @@ extra_pages = [] allowed_dirs = set() -allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] -if shared.opts.samples_format not in allowed_preview_extensions: - allowed_preview_extensions.append(shared.opts.samples_format) -allowed_preview_extensions_dot = ['.' + extension for extension in allowed_preview_extensions] +default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] + + +@functools.cache +def allowed_preview_extensions_with_extra(extra_extensions=None): + return set(default_allowed_preview_extensions) | set(extra_extensions or []) + + +def allowed_preview_extensions(): + return allowed_preview_extensions_with_extra((shared.opts.samples_format, )) def register_page(page): @@ -38,9 +45,9 @@ def fetch_file(filename: str = ""): if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") - ext = os.path.splitext(filename)[1].lower() - if ext not in allowed_preview_extensions_dot: - raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.") + ext = os.path.splitext(filename)[1].lower()[1:] + if ext not in allowed_preview_extensions(): + raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.") # would profit from returning 304 return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) @@ -278,7 +285,7 @@ def find_preview(self, path): Find a preview PNG for a given path (without extension) and call link_preview on it. """ - potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions], []) + potential_files = sum([[path + "." + ext, path + ".preview." 
+ ext] for ext in allowed_preview_extensions()], []) for file in potential_files: if os.path.isfile(file): From 0935d2c3047210b799cbc6f8ce15d3dffca95af7 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 30 Sep 2023 18:37:44 -0600 Subject: [PATCH 0131/1174] Use checkboxes for whitespace delimiters --- javascript/edit-attention.js | 8 ++++++-- modules/shared_options.py | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index bc4ebed4ca7..943d81b080d 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -46,8 +46,12 @@ function keyupEditAttention(event) { function selectCurrentWord() { if (selectionStart !== selectionEnd) return false; - let delimiters = opts.keyedit_delimiters.replace(/(^|[^\\])(\\\\)*\\$/, "$&\\").replace(/(^|[^\\])((\\\\)*")/g, "$1\\$2"); - delimiters = JSON.parse(`"${delimiters}"`); + const whitespace_delimiters = {"Tab": "\t", "Carriage Return": "\r", "Line Feed": "\n"}; + let delimiters = opts.keyedit_delimiters; + + for (let i of opts.keyedit_delimiters_whitespace) { + delimiters += whitespace_delimiters[i]; + } // seek backward until to find beggining while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { diff --git a/modules/shared_options.py b/modules/shared_options.py index a1f157c6603..717c948b59a 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -255,7 +255,8 @@ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_delimiters": OptionInfo(r".,\\/!?%^*;:{}=`~() \r\n\t", "Ctrl+up/down word delimiters"), + "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"), + "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), From 0eb5fde2fd42184f431ccba5f25d714272e3e6b0 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 30 Sep 2023 21:20:58 -0600 Subject: [PATCH 0132/1174] Remove unneeded code --- javascript/edit-attention.js | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 943d81b080d..df218c1e461 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -18,22 +18,11 @@ function keyupEditAttention(event) { const before = text.substring(0, selectionStart); let beforeParen = before.lastIndexOf(OPEN); if (beforeParen == -1) return false; - let beforeParenClose = before.lastIndexOf(CLOSE); - while (beforeParenClose !== -1 
&& beforeParenClose > beforeParen) { - beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); - beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); - } // Find closing parenthesis around current cursor const after = text.substring(selectionStart); let afterParen = after.indexOf(CLOSE); if (afterParen == -1) return false; - let afterParenOpen = after.indexOf(OPEN); - while (afterParenOpen !== -1 && afterParen > afterParenOpen) { - afterParen = after.indexOf(CLOSE, afterParen + 1); - afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); - } - if (beforeParen === -1 || afterParen === -1) return false; // Set the selection to the text between the parenthesis const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); @@ -53,7 +42,7 @@ function keyupEditAttention(event) { delimiters += whitespace_delimiters[i]; } - // seek backward until to find beggining + // seek backward to find beginning while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { selectionStart--; } From 56ef5e9d48750fd43f9faba31ff67e64368153b7 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 30 Sep 2023 21:44:05 -0600 Subject: [PATCH 0133/1174] Remove end parenthesis from weight --- javascript/edit-attention.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index df218c1e461..794453bfeab 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -86,7 +86,7 @@ function keyupEditAttention(event) { } var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; - var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); + var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end)); if (isNaN(weight)) return; weight += isPlus ? 
delta : -delta; From 2d947175b902d6838c803036d9757e7d3226b41d Mon Sep 17 00:00:00 2001 From: superhero-7 <537093830@qq.com> Date: Sun, 1 Oct 2023 12:25:19 +0800 Subject: [PATCH 0134/1174] fix linter issues --- modules/sd_hijack.py | 4 ++-- modules/sd_models_config.py | 3 +-- modules/xlmr_m18.py | 12 ++++++------ 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 4b36c0e9c13..0689699cf19 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -212,7 +212,7 @@ def hijack(self, m): model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) - + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) @@ -258,7 +258,7 @@ def undo_hijack(self, m): if hasattr(m, 'cond_stage_model'): delattr(m, 'cond_stage_model') - + elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9ba89dfc0eb..deab2f6e237 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -95,8 +95,7 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8: return config_instruct_pix2pix - - # import pdb; pdb.set_trace() + if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: return config_alt_diffusion_m18 diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index 18785692a65..a727e865529 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -1,4 +1,4 @@ -from transformers import BertPreTrainedModel,BertModel,BertConfig +from transformers import BertPreTrainedModel,BertConfig import torch.nn as nn import torch from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig @@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): config_class = BertSeriesConfig def __init__(self, config=None, **kargs): - # modify initialization for autoloading + # modify initialization for autoloading if config is None: config = XLMRobertaConfig() config.attention_probs_dropout_prob= 0.1 @@ -80,7 +80,7 @@ def encode(self,c): text["attention_mask"] = torch.tensor( text['attention_mask']).to(device) features = self(**text) - return features['projection_state'] + return features['projection_state'] def forward( self, @@ -147,8 +147,8 @@ def forward( "hidden_states": outputs.hidden_states, "attentions": outputs.attentions, } - - + + # return { # 'pooler_output':pooler_output, # 'last_hidden_state':outputs.last_hidden_state, @@ -161,4 +161,4 @@ def forward( class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): base_model_prefix = 'roberta' - config_class= RobertaSeriesConfig \ No newline at end of file + config_class= RobertaSeriesConfig From c7e810a9856641fbaf520976fde24c5536a4fd56 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 1 Oct 2023 10:15:23 +0300 Subject: [PATCH 0135/1174] add onEdit function for js and rework token-counter.js to use it --- .eslintrc.js | 1 + 
javascript/token-counters.js | 26 +++++++++----------------- javascript/ui.js | 17 +++++++++++++++++ 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/.eslintrc.js b/.eslintrc.js index 4777c276e9b..cf8397695e1 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -74,6 +74,7 @@ module.exports = { create_submit_args: "readonly", restart_reload: "readonly", updateInput: "readonly", + onEdit: "readonly", //extraNetworks.js requestGet: "readonly", popup: "readonly", diff --git a/javascript/token-counters.js b/javascript/token-counters.js index 9d81a723b01..710345eb001 100644 --- a/javascript/token-counters.js +++ b/javascript/token-counters.js @@ -1,10 +1,9 @@ -let promptTokenCountDebounceTime = 800; -let promptTokenCountTimeouts = {}; -var promptTokenCountUpdateFunctions = {}; +let promptTokenCountUpdateFunctions = {}; function update_txt2img_tokens(...args) { // Called from Gradio update_token_counter("txt2img_token_button"); + update_token_counter("txt2img_negative_token_button"); if (args.length == 2) { return args[0]; } @@ -14,6 +13,7 @@ function update_txt2img_tokens(...args) { function update_img2img_tokens(...args) { // Called from Gradio update_token_counter("img2img_token_button"); + update_token_counter("img2img_negative_token_button"); if (args.length == 2) { return args[0]; } @@ -21,16 +21,7 @@ function update_img2img_tokens(...args) { } function update_token_counter(button_id) { - if (opts.disable_token_counters) { - return; - } - if (promptTokenCountTimeouts[button_id]) { - clearTimeout(promptTokenCountTimeouts[button_id]); - } - promptTokenCountTimeouts[button_id] = setTimeout( - () => gradioApp().getElementById(button_id)?.click(), - promptTokenCountDebounceTime, - ); + promptTokenCountUpdateFunctions[button_id]?.(); } @@ -69,10 +60,11 @@ function setupTokenCounting(id, id_counter, id_button) { prompt.parentElement.insertBefore(counter, prompt); prompt.parentElement.style.position = "relative"; - promptTokenCountUpdateFunctions[id] = function() { - update_token_counter(id_button); - }; - textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]); + func = onEdit(id, textarea, 800, function() { + gradioApp().getElementById(id_button)?.click(); + }); + promptTokenCountUpdateFunctions[id] = func; + promptTokenCountUpdateFunctions[id_button] = func; } function setupTokenCounters() { diff --git a/javascript/ui.js b/javascript/ui.js index bedcbf3e211..aee0d1da9a7 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -366,3 +366,20 @@ function switchWidthHeight(tabname) { updateInput(height); return []; } + + +var onEditTimers = {}; + +// calls func after afterMs milliseconds has passed since the input elem has been edited by the user +function onEdit(editId, elem, afterMs, func) { + var edited = function() { + var existingTimer = onEditTimers[editId]; + if (existingTimer) clearTimeout(existingTimer); + + onEditTimers[editId] = setTimeout(func, afterMs); + }; + + elem.addEventListener("input", edited); + + return edited; +} From c0113872c5f814cf8cf96deca541bffaf1af2568 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 1 Oct 2023 00:25:53 +0900 Subject: [PATCH 0136/1174] fix fieldname regex to accept additional [-/] chars --- modules/generation_parameters_copypaste.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index d39f2ebac36..f1b0225e3c9 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -9,7 
+9,7 @@ from modules import shared, ui_tempdir, script_callbacks, processing from PIL import Image -re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*(\w[\w -/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") From c0113872c5f814cf8cf96deca541bffaf1af2568 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 1 Oct 2023 11:48:41 +0300 Subject: [PATCH 0137/1174] add search field to settings --- javascript/settings.js | 46 ++++++++++++++++++++++++++++++++++++ javascript/token-counters.js | 2 +- javascript/ui.js | 15 ------------ modules/ui_settings.py | 16 +++++++++++-- style.css | 2 ++ 5 files changed, 63 insertions(+), 18 deletions(-) create mode 100644 javascript/settings.js diff --git a/javascript/settings.js b/javascript/settings.js new file mode 100644 index 00000000000..7889cf1cd9c --- /dev/null +++ b/javascript/settings.js @@ -0,0 +1,46 @@ +let settingsExcludeTabsFromShowAll = { + settings_tab_defaults: 1, + settings_tab_sysinfo: 1, + settings_tab_actions: 1, + settings_tab_licenses: 1, +}; + +function settingsShowAllTabs() { + gradioApp().querySelectorAll('#settings > div').forEach(function(elem) { + if (settingsExcludeTabsFromShowAll[elem.id]) return; + + elem.style.display = "block"; + }); +} + +function settingsShowOneTab() { + gradioApp().querySelector('#settings_show_one_page').click(); +} + +onUiLoaded(function() { + var edit = gradioApp().querySelector('#settings_search'); + var editTextarea = gradioApp().querySelector('#settings_search > label > input'); + var buttonShowAllPages = gradioApp().getElementById('settings_show_all_pages'); + var settings_tabs = gradioApp().querySelector('#settings div'); + + onEdit('settingsSearch', editTextarea, 250, function() { + var searchText = (editTextarea.value || "").trim(); + + gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > *').forEach(function(elem) { + var visible = elem.textContent.trim().indexOf(searchText) != -1; + elem.style.display = visible ? 
"" : "none"; + }); + + if (searchText != "") { + settingsShowAllTabs(); + } else { + settingsShowOneTab(); + } + }); + + settings_tabs.insertBefore(edit, settings_tabs.firstChild); + settings_tabs.appendChild(buttonShowAllPages); + + + buttonShowAllPages.addEventListener("click", settingsShowAllTabs); +}); diff --git a/javascript/token-counters.js b/javascript/token-counters.js index 710345eb001..2ecc7d91010 100644 --- a/javascript/token-counters.js +++ b/javascript/token-counters.js @@ -60,7 +60,7 @@ function setupTokenCounting(id, id_counter, id_button) { prompt.parentElement.insertBefore(counter, prompt); prompt.parentElement.style.position = "relative"; - func = onEdit(id, textarea, 800, function() { + var func = onEdit(id, textarea, 800, function() { gradioApp().getElementById(id_button)?.click(); }); promptTokenCountUpdateFunctions[id] = func; diff --git a/javascript/ui.js b/javascript/ui.js index aee0d1da9a7..2e262602044 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -263,21 +263,6 @@ onAfterUiUpdate(function() { json_elem.parentElement.style.display = "none"; setupTokenCounters(); - - var show_all_pages = gradioApp().getElementById('settings_show_all_pages'); - var settings_tabs = gradioApp().querySelector('#settings div'); - if (show_all_pages && settings_tabs) { - settings_tabs.appendChild(show_all_pages); - show_all_pages.onclick = function() { - gradioApp().querySelectorAll('#settings > div').forEach(function(elem) { - if (elem.id == "settings_tab_licenses") { - return; - } - - elem.style.display = "block"; - }); - }; - } }); onOptionsChanged(function() { diff --git a/modules/ui_settings.py b/modules/ui_settings.py index c6fe3604a1a..74a3aef32cc 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -64,6 +64,9 @@ class UiSettings: quicksettings_list = None quicksettings_names = None text_settings = None + show_all_pages = None + show_one_page = None + search_input = None def run_settings(self, *args): changed = [] @@ -136,7 +139,7 @@ def create_ui(self, loadsave, dummy_component): gr.Group() current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) current_tab.__enter__() - current_row = gr.Column(variant='compact') + current_row = gr.Column(elem_id=f"column_settings_{elem_id}", variant='compact') current_row.__enter__() previous_section = item.section @@ -183,7 +186,11 @@ def create_ui(self, loadsave, dummy_component): with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): gr.HTML(shared.html("licenses.html"), elem_id="licenses") - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + self.show_all_pages = gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + self.show_one_page = gr.Button(value="Show only one page", elem_id="settings_show_one_page", visible=False) + self.show_one_page.click(lambda: None) + + self.search_input = gr.Textbox(value="", elem_id="settings_search", max_lines=1, placeholder="Search...", show_label=False) self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) @@ -313,3 +320,8 @@ def get_settings_values(): outputs=[self.component_dict[k] for k in component_keys], queue=False, ) + + def search(self, text): + print(text) + + return [gr.update(visible=text in (comp.label or "")) for comp in self.components] diff --git a/style.css b/style.css index 58eb29c1598..eee5055280d 100644 --- a/style.css +++ b/style.css @@ -423,6 +423,7 @@ div#extras_scale_to_tab div.form{ #settings > div{ border: none; margin-left: 10em; + 
padding: 0 var(--spacing-xl); } #settings > div.tab-nav{ @@ -437,6 +438,7 @@ div#extras_scale_to_tab div.form{ border: none; text-align: left; white-space: initial; + padding: 4px; } #settings_result{ From f71e919ecb001c4d191b76a87477d6118de7be12 Mon Sep 17 00:00:00 2001 From: FluttyProger Date: Sun, 1 Oct 2023 18:06:48 +0300 Subject: [PATCH 0138/1174] Ability for extensions to return custom data via api in response.images --- modules/api/api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 905ef9c9536..efedafa42e2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -103,7 +103,8 @@ def decode_base64_to_image(encoding): def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - + if isinstance(image, str): + return image if opts.samples_format.lower() == 'png': use_metadata = False metadata = PngImagePlugin.PngInfo() From 3f763d41e8ff7f09f89adb00eec440f18566d260 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sun, 1 Oct 2023 22:38:27 -0600 Subject: [PATCH 0139/1174] Change denoising_strength default to None. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index e124e7f0dd2..061d9955a24 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -142,7 +142,7 @@ class StableDiffusionProcessing: overlay_images: list = None eta: float = None do_not_reload_embeddings: bool = False - denoising_strength: float = 0 + denoising_strength: float = None ddim_discretize: str = None s_min_uncond: float = None s_churn: float = None From 6ab0b65ed15730f590a3e530f00a2047f582b488 Mon Sep 17 00:00:00 2001 From: PermissionDenied7335 Date: Mon, 2 Oct 2023 15:43:59 +0800 Subject: [PATCH 0140/1174] Added an option not to enable venv --- webui.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/webui.sh b/webui.sh index 3d0f87eed74..f3621903bf6 100755 --- a/webui.sh +++ b/webui.sh @@ -4,12 +4,6 @@ # change the variables in webui-user.sh instead # ################################################# - -use_venv=1 -if [[ $venv_dir == "-" ]]; then - use_venv=0 -fi - SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) @@ -28,6 +22,12 @@ then source "$SCRIPT_DIR"/webui-user.sh fi +# If $venv_dir is "-", then disable venv support +use_venv=1 +if [[ $venv_dir == "-" ]]; then + use_venv=0 +fi + # Set defaults # Install directory without trailing slash if [[ -z "${install_dir}" ]] From c2279da52260410ac2b4f64b7de7b2866f3a07a1 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Tue, 3 Oct 2023 01:16:41 +0900 Subject: [PATCH 0141/1174] fix re_param_code (regression bug PR #13458) --- modules/generation_parameters_copypaste.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index f1b0225e3c9..0a606515b79 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -9,7 +9,7 @@ from modules import shared, ui_tempdir, script_callbacks, processing from PIL import Image -re_param_code = r'\s*(\w[\w -/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") From 86a46e81892c72cc50f9a37dfb1ca285f189134d Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Mon, 2 
Oct 2023 22:22:15 -0600 Subject: [PATCH 0142/1174] Fix accidentally closing popup dialogs --- javascript/extraNetworks.js | 11 ++++------- style.css | 14 ++++++++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 493f31af28a..b1b85669b02 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -216,27 +216,24 @@ function extraNetworksSearchButton(tabs_id, event) { var globalPopup = null; var globalPopupInner = null; + function closePopup() { if (!globalPopup) return; - globalPopup.style.display = "none"; } + function popup(contents) { if (!globalPopup) { globalPopup = document.createElement('div'); - globalPopup.onclick = closePopup; globalPopup.classList.add('global-popup'); - + var close = document.createElement('div'); close.classList.add('global-popup-close'); - close.onclick = closePopup; + close.addEventListener("click", closePopup); close.title = "Close"; globalPopup.appendChild(close); globalPopupInner = document.createElement('div'); - globalPopupInner.onclick = function(event) { - event.stopPropagation(); return false; - }; globalPopupInner.classList.add('global-popup-inner'); globalPopup.appendChild(globalPopupInner); diff --git a/style.css b/style.css index fb4e2f1f02e..e034ecfd0f0 100644 --- a/style.css +++ b/style.css @@ -581,7 +581,6 @@ table.popup-table .link{ width: 100%; height: 100%; overflow: auto; - background-color: rgba(20, 20, 20, 0.95); } .global-popup *{ @@ -590,9 +589,6 @@ table.popup-table .link{ .global-popup-close:before { content: "×"; -} - -.global-popup-close{ position: fixed; right: 0.25em; top: 0; @@ -601,10 +597,20 @@ table.popup-table .link{ font-size: 32pt; } +.global-popup-close{ + position: fixed; + left: 0; + top: 0; + width: 100%; + height: 100%; + background-color: rgba(20, 20, 20, 0.95); +} + .global-popup-inner{ display: inline-block; margin: auto; padding: 2em; + z-index: 1001; } /* fullpage image viewer */ From e5381320b99f657aadad0bf8f414108a96567b3c Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Mon, 2 Oct 2023 22:33:03 -0600 Subject: [PATCH 0143/1174] Lint --- javascript/extraNetworks.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index b1b85669b02..439e76a3f6d 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -226,7 +226,7 @@ function popup(contents) { if (!globalPopup) { globalPopup = document.createElement('div'); globalPopup.classList.add('global-popup'); - + var close = document.createElement('div'); close.classList.add('global-popup-close'); close.addEventListener("click", closePopup); From 7d60076b8b275771a1aa98f017aff845ef68d964 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 3 Oct 2023 16:22:32 +0300 Subject: [PATCH 0144/1174] case-insensitive search for settings --- javascript/settings.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/javascript/settings.js b/javascript/settings.js index 7889cf1cd9c..4e79ec003a5 100644 --- a/javascript/settings.js +++ b/javascript/settings.js @@ -24,10 +24,10 @@ onUiLoaded(function() { var settings_tabs = gradioApp().querySelector('#settings div'); onEdit('settingsSearch', editTextarea, 250, function() { - var searchText = (editTextarea.value || "").trim(); + var searchText = (editTextarea.value || "").trim().toLowerCase(); gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > 
*').forEach(function(elem) { - var visible = elem.textContent.trim().indexOf(searchText) != -1; + var visible = elem.textContent.trim().toLowerCase().indexOf(searchText) != -1; elem.style.display = visible ? "" : "none"; }); From 35fd24e857e625c090e9af3fdcef145b88bef436 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 3 Oct 2023 23:05:48 +0900 Subject: [PATCH 0145/1174] Less placeholder bug_report template --- .github/ISSUE_TEMPLATE/bug_report.yml | 44 +++++---------------------- 1 file changed, 8 insertions(+), 36 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index a423f052d16..5876e941085 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -39,10 +39,7 @@ body: label: What happened? description: Tell us what happened in a very clear and simple way placeholder: | - I tried to use txt2img with the XYZ grid script, with DPM++ SDE, DPM++ 2M SDE samplers. - It should generate a grid of 2 images but I only got 1. - - (add screenshot or screen recording if necessary) + txt2img is not working as intended. validations: required: true - type: textarea @@ -51,9 +48,9 @@ body: label: Steps to reproduce the problem description: Please provide us with precise step by step instructions on how to reproduce the bug placeholder: | - 1. Go to txt2img tab, select XYZ grid script - 2. Set axis type to `Sampler`, and select DPM++ 2M SDE, DPM++ 3M SDE - 3. Set `Sampling steps` to 1, click generate button + 1. Go to ... + 2. Press ... + 3. ... validations: required: true - type: textarea @@ -62,8 +59,7 @@ body: label: What should have happened? description: Tell us what you think the normal behavior should be placeholder: | - It should generate a grid of 2 images. - This was working in webui version 1.x.x + WebUI should ... validations: required: true - type: dropdown @@ -97,37 +93,13 @@ body: label: Console logs description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occured. If it's very long, provide a link to pastebin or similar service. 
render: Shell - placeholder: | - generating image for xyz plot: UnboundLocalError - Traceback (most recent call last): - File "B:\GitHub\stable-diffusion-webui\scripts\xyz_grid.py", line 698, in cell - res = process_images(pc) - File "B:\GitHub\stable-diffusion-webui\modules\processing.py", line 732, in process_images - res = process_images_inner(p) - File "B:\GitHub\stable-diffusion-webui\modules\processing.py", line 867, in process_images_inner - samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) - File "B:\GitHub\stable-diffusion-webui\modules\processing.py", line 1140, in sample - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) - File "B:\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 235, in sample - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) - File "B:\GitHub\stable-diffusion-webui\modules\sd_samplers_common.py", line 261, in launch_sampling - return func() - File "B:\GitHub\stable-diffusion-webui\modules\sd_samplers_kdiffusion.py", line 235, in - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) - File "B:\GitHub\stable-diffusion-webui\venv\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context - return func(*args, **kwargs) - File "B:\GitHub\stable-diffusion-webui\repositories\k-diffusion\k_diffusion\sampling.py", line 651, in sample_dpmpp_2m_sde - h_last = h - UnboundLocalError: local variable 'h' referenced before assignment validations: required: true - type: textarea id: misc attributes: label: Additional information - description: Please provide us with any relevant additional info or context. - placeholder: | + description: | + Please provide us with any relevant additional info or context. Examples: - I have updated the GPU driver recently. - I suspect the issue is caused by XXXXX. - I am using a VPN. +  I have updated my GPU driver recently. 
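The edit-attention patches in this series (0128 and 0131-0133 above; 0146, 0150 and 0151 below) all keep revising the same weight-update step in javascript/edit-attention.js. As a reading aid, here is a minimal sketch of that step; the standalone function and its names are illustrative assumptions, not the actual webui code, which edits the prompt textarea in place:

// Sketch: bump the numeric weight of a selected "(token:1.1)" span.
// `weightText` stands for the digits after the colon; `delta` plays the role
// of opts.keyedit_precision_attention (0.1 by default, per shared_options.py).
function adjustWeight(weightText, isPlus, delta = 0.1) {
    let weight = parseFloat(weightText);          // e.g. "1.1" -> 1.1
    if (isNaN(weight)) return weightText;         // not a weighted span: leave it unchanged
    weight += isPlus ? delta : -delta;            // Ctrl+Up adds, Ctrl+Down subtracts
    weight = parseFloat(weight.toPrecision(12));  // trim float noise like 1.1000000000000001
    const s = String(weight);
    return s.length == 1 ? s + ".0" : s;          // keep the "1.0" style the UI writes back
}

Patch 0146 below guards this step with regex tests so it only fires on spans that actually end in ":number"; 0150 adds the /s flag so those tests survive multi-line selections, and 0151 adds "-?" so negative weights parse as well.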
From e34949be52a89af21b2bcb0c18ca8d834e9cb562 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Fri, 6 Oct 2023 22:49:33 -0600 Subject: [PATCH 0146/1174] Edit-attention fixes --- javascript/edit-attention.js | 56 ++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 794453bfeab..b7f3215d0b5 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -26,6 +26,7 @@ function keyupEditAttention(event) { // Set the selection to the text between the parenthesis const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); + if (!/.*:[\d.]+/.test(parenContent)) return false; const lastColon = parenContent.lastIndexOf(":"); selectionStart = beforeParen + 1; selectionEnd = selectionStart + lastColon; @@ -66,40 +67,53 @@ function keyupEditAttention(event) { var closeCharacter = ')'; var delta = opts.keyedit_precision_attention; - if (selectionStart > 0 && text[selectionStart - 1] == '<') { + if (selectionStart > 0 && /<.*:[\d.]+>/.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(">") + 1))) { closeCharacter = '>'; delta = opts.keyedit_precision_extra; - } else if (selectionStart == 0 || text[selectionStart - 1] != "(") { - + } else if (selectionStart > 0 && /\(.*\)|\[.*\]/.test(text.slice(selectionStart - 1, selectionEnd + 1))) { + closeCharacter = null; + if (isPlus) { + text = text.slice(0, selectionStart) + text[selectionStart - 1] + text.slice(selectionStart, selectionEnd) + text[selectionEnd] + text.slice(selectionEnd); + selectionStart++; + selectionEnd++; + } else { + text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 1); + selectionStart--; + selectionEnd--; + } + } else if (selectionStart == 0 || !/\(.*:[\d.]+\)/.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(")") + 1))) { // do not include spaces at the end while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { - selectionEnd -= 1; + selectionEnd--; } + if (selectionStart == selectionEnd) { return; } text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd); - selectionStart += 1; - selectionEnd += 1; + selectionStart++; + selectionEnd++; } - var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; - var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end)); - if (isNaN(weight)) return; - - weight += isPlus ? delta : -delta; - weight = parseFloat(weight.toPrecision(12)); - if (String(weight).length == 1) weight += ".0"; - - if (closeCharacter == ')' && weight == 1) { - var endParenPos = text.substring(selectionEnd).indexOf(')'); - text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1); - selectionStart--; - selectionEnd--; - } else { - text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); + if (closeCharacter) { + var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end)); + if (isNaN(weight)) return; + + weight += isPlus ? 
delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if (String(weight).length == 1) weight += ".0"; + + if (closeCharacter == ')' && weight == 1) { + var endParenPos = text.substring(selectionEnd).indexOf(')'); + text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1); + selectionStart--; + selectionEnd--; + } else { + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); + } } target.focus(); From 76010a51ef1f3805a7487723599035bc2356c3fb Mon Sep 17 00:00:00 2001 From: wangqiuwen Date: Sat, 7 Oct 2023 15:36:01 +0800 Subject: [PATCH 0147/1174] up --- modules/sd_models.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index eedb38c65ad..3a060ab685a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,4 +1,5 @@ import collections +import copy import os.path import sys import gc @@ -309,8 +310,6 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): if checkpoint_info in checkpoints_loaded: # use checkpoint cache print(f"Loading weights [{sd_model_hash}] from cache") - # move to end as latest - checkpoints_loaded.move_to_end(checkpoint_info) return checkpoints_loaded[checkpoint_info] print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") @@ -352,12 +351,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if model.is_sdxl: sd_models_xl.extend_sdxl(model) - model.load_state_dict(state_dict, strict=False) - timer.record("apply weights to model") - if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model - checkpoints_loaded[checkpoint_info] = state_dict + checkpoints_loaded[checkpoint_info] = copy.deepcopy(state_dict) + + model.load_state_dict(state_dict, strict=False) + timer.record("apply weights to model") del state_dict From 770ee23f18d12fb3b5627c636aa420f481e292ee Mon Sep 17 00:00:00 2001 From: wangqiuwen Date: Sat, 7 Oct 2023 15:38:50 +0800 Subject: [PATCH 0148/1174] revert --- modules/sd_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3a060ab685a..8d63e7f10f7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -310,6 +310,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): if checkpoint_info in checkpoints_loaded: # use checkpoint cache print(f"Loading weights [{sd_model_hash}] from cache") + # move to end as latest + checkpoints_loaded.move_to_end(checkpoint_info) return checkpoints_loaded[checkpoint_info] print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") From 09a2da835ef9c3dd12f3d842fbd5dd9e19ded859 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 7 Oct 2023 14:48:43 -0600 Subject: [PATCH 0149/1174] Add brackets, vertical bar to default delimiters --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index ab9b0072861..4824323443c 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -256,7 +256,7 @@ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", 
gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"), + "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~()[]<>| ", "Ctrl+up/down word delimiters"), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), From fd51b8501e1d9f9380c948cbf665c7708baef5d6 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 7 Oct 2023 15:28:25 -0600 Subject: [PATCH 0150/1174] Fix multi-line selections --- javascript/edit-attention.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index b7f3215d0b5..d7b6001b5db 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -26,7 +26,7 @@ function keyupEditAttention(event) { // Set the selection to the text between the parenthesis const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); - if (!/.*:[\d.]+/.test(parenContent)) return false; + if (!/.*:[\d.]+/s.test(parenContent)) return false; const lastColon = parenContent.lastIndexOf(":"); selectionStart = beforeParen + 1; selectionEnd = selectionStart + lastColon; @@ -67,10 +67,10 @@ function keyupEditAttention(event) { var closeCharacter = ')'; var delta = opts.keyedit_precision_attention; - if (selectionStart > 0 && /<.*:[\d.]+>/.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(">") + 1))) { + if (selectionStart > 0 && /<.*:[\d.]+>/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(">") + 1))) { closeCharacter = '>'; delta = opts.keyedit_precision_extra; - } else if (selectionStart > 0 && /\(.*\)|\[.*\]/.test(text.slice(selectionStart - 1, selectionEnd + 1))) { + } else if (selectionStart > 0 && /\(.*\)|\[.*\]/s.test(text.slice(selectionStart - 1, selectionEnd + 1))) { closeCharacter = null; if (isPlus) { text = text.slice(0, selectionStart) + text[selectionStart - 1] + text.slice(selectionStart, selectionEnd) + text[selectionEnd] + text.slice(selectionEnd); @@ -81,7 +81,7 @@ function keyupEditAttention(event) { selectionStart--; selectionEnd--; } - } else if (selectionStart == 0 || !/\(.*:[\d.]+\)/.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(")") + 1))) { + } else if (selectionStart == 0 || !/\(.*:[\d.]+\)/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(")") + 1))) { // do not include spaces at the end while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { selectionEnd--; From 3562b0dc7427e92828001cd713ff47a83255ccc8 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 7 Oct 2023 15:52:16 -0600 Subject: [PATCH 0151/1174] Fix negative values --- javascript/edit-attention.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index d7b6001b5db..89b37aaf41a 100644 --- a/javascript/edit-attention.js +++ 
b/javascript/edit-attention.js @@ -26,7 +26,7 @@ function keyupEditAttention(event) { // Set the selection to the text between the parenthesis const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); - if (!/.*:[\d.]+/s.test(parenContent)) return false; + if (!/.*:-?[\d.]+/s.test(parenContent)) return false; const lastColon = parenContent.lastIndexOf(":"); selectionStart = beforeParen + 1; selectionEnd = selectionStart + lastColon; @@ -67,7 +67,7 @@ function keyupEditAttention(event) { var closeCharacter = ')'; var delta = opts.keyedit_precision_attention; - if (selectionStart > 0 && /<.*:[\d.]+>/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(">") + 1))) { + if (selectionStart > 0 && /<.*:-?[\d.]+>/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(">") + 1))) { closeCharacter = '>'; delta = opts.keyedit_precision_extra; } else if (selectionStart > 0 && /\(.*\)|\[.*\]/s.test(text.slice(selectionStart - 1, selectionEnd + 1))) { @@ -81,7 +81,7 @@ function keyupEditAttention(event) { selectionStart--; selectionEnd--; } - } else if (selectionStart == 0 || !/\(.*:[\d.]+\)/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(")") + 1))) { + } else if (selectionStart == 0 || !/\(.*:-?[\d.]+\)/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(")") + 1))) { // do not include spaces at the end while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { selectionEnd--; From 9821625a76177ebc8b62a1ee6d8ef39cf4805f99 Mon Sep 17 00:00:00 2001 From: Leon Date: Mon, 9 Oct 2023 18:36:48 +0800 Subject: [PATCH 0152/1174] fix the KeyError exception when overriding a key that is defined by an extension --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 36bc94f7867..fee2440ff58 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -711,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.before_process(p) - stored_opts = {k: opts.data[k] for k in p.override_settings.keys() if k in opts.data} + stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data} try: # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint From 2aa485b5afb13fd6aab79777e4dfc488591b2f1c Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 9 Oct 2023 22:52:09 +0800 Subject: [PATCH 0153/1174] add lora bundle system --- extensions-builtin/Lora/network.py | 1 + extensions-builtin/Lora/networks.py | 48 +++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index d8e8dfb7ff0..6021fd8de0f 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -93,6 +93,7 @@ def __init__(self, name, network_on_disk: NetworkOnDisk): self.unet_multiplier = 1.0 self.dyn_dim = None self.modules = {} + self.bundle_embeddings = {} self.mtime = None self.mentioned_name = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 315682b31b0..652b8ebedfd 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -15,6 +15,7 @@ from typing 
import Union from modules import shared, devices, sd_models, errors, scripts, sd_hijack +from modules.textual_inversion.textual_inversion import Embedding module_types = [ network_lora.ModuleTypeLora(), @@ -149,9 +150,15 @@ def load_network(name, network_on_disk): is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping matched_networks = {} + bundle_embeddings = {} for key_network, weight in sd.items(): key_network_without_network_parts, network_part = key_network.split(".", 1) + if key_network_without_network_parts == "bundle_emb": + emb_name, vec_name = network_part.split(".", 1) + emb_dict = bundle_embeddings.get(emb_name, {}) + emb_dict[vec_name] = weight + bundle_embeddings[emb_name] = emb_dict key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -195,6 +202,8 @@ def load_network(name, network_on_disk): net.modules[key] = net_module + net.bundle_embeddings = bundle_embeddings + if keys_failed_to_match: logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}") @@ -210,11 +219,14 @@ def purge_networks_from_memory(): def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): + emb_db = sd_hijack.model_hijack.embedding_db already_loaded = {} for net in loaded_networks: if net.name in names: already_loaded[net.name] = net + for emb_name in net.bundle_embeddings: + emb_db.register_embedding_by_name(None, shared.sd_model, emb_name) loaded_networks.clear() @@ -257,6 +269,41 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 loaded_networks.append(net) + for emb_name, data in net.bundle_embeddings.items(): + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + else: + raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.") + + embedding = Embedding(vec, emb_name) + embedding.vectors = vectors + embedding.shape = shape + + if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: + emb_db.register_embedding(embedding, shared.sd_model) + else: + emb_db.skipped_embeddings[name] = embedding + if failed_to_load_networks: sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) @@ -565,6 +612,7 @@ def infotext_pasted(infotext, params): available_networks = {} available_network_aliases = {} 
loaded_networks = [] +loaded_bundle_embeddings = {} networks_in_memory = {} available_network_hash_lookup = {} forbidden_network_aliases = {} From 3d8b1af6beb9015f6b3573661d8ed00275f6129f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 10 Oct 2023 12:09:33 +0800 Subject: [PATCH 0154/1174] Support string_to_param nested dict format: bundle_emb.EMBNAME.string_to_param.KEYNAME --- extensions-builtin/Lora/networks.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 652b8ebedfd..ab3517d8163 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -157,7 +157,11 @@ def load_network(name, network_on_disk): if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) - emb_dict[vec_name] = weight + if vec_name.split('.')[0] == 'string_to_param': + _, k2 = vec_name.split('.', 1) + emb_dict['string_to_param'] = {k2: weight} + else: + emb_dict[vec_name] = weight bundle_embeddings[emb_name] = emb_dict key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) @@ -301,6 +305,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: emb_db.register_embedding(embedding, shared.sd_model) + print(f'registered bundle embedding: {embedding.name}') else: emb_db.skipped_embeddings[name] = embedding From 2282eb8dd5905e8ed71231a0b8fc77599d10c12f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 10 Oct 2023 12:11:00 +0800 Subject: [PATCH 0155/1174] Remove dev debug print --- extensions-builtin/Lora/networks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index ab3517d8163..465e24c8c77 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -305,7 +305,6 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: emb_db.register_embedding(embedding, shared.sd_model) - print(f'registered bundle embedding: {embedding.name}') else: emb_db.skipped_embeddings[name] = embedding From 81e94de3185d42dba4e7bb72cf836f683f28b03f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:44:20 +0800 Subject: [PATCH 0156/1174] Add a warning when embedding names conflict Prefer the standalone embedding (in the /embeddings folder) --- extensions-builtin/Lora/lora_logger.py | 33 +++++++++++ extensions-builtin/Lora/networks.py | 80 +++++++++++++++----------- 2 files changed, 81 insertions(+), 32 deletions(-) create mode 100644 extensions-builtin/Lora/lora_logger.py diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py new file mode 100644 index 00000000000..d50e90f0940 --- /dev/null +++ b/extensions-builtin/Lora/lora_logger.py @@ -0,0 +1,33 @@ +import sys +import copy +import logging + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + 
} + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +logger = logging.getLogger("lora") +logger.propagate = False + + +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s") + ) + logger.addHandler(handler) \ No newline at end of file diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 465e24c8c77..12f70576981 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -17,6 +17,8 @@ from modules import shared, devices, sd_models, errors, scripts, sd_hijack from modules.textual_inversion.textual_inversion import Embedding +from lora_logger import logger + module_types = [ network_lora.ModuleTypeLora(), network_hada.ModuleTypeHada(), @@ -206,7 +208,40 @@ def load_network(name, network_on_disk): net.modules[key] = net_module - net.bundle_embeddings = bundle_embeddings + embeddings = {} + for emb_name, data in bundle_embeddings.items(): + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + else: + raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.") + + embedding = Embedding(vec, emb_name) + embedding.vectors = vectors + embedding.shape = shape + embedding.loaded = None + embeddings[emb_name] = embedding + + net.bundle_embeddings = embeddings if keys_failed_to_match: logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}") @@ -229,8 +264,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No for net in loaded_networks: if net.name in names: already_loaded[net.name] = net - for emb_name in net.bundle_embeddings: - emb_db.register_embedding_by_name(None, shared.sd_model, emb_name) + for emb_name, embedding in net.bundle_embeddings.items(): + if embedding.loaded: + emb_db.register_embedding_by_name(None, shared.sd_model, emb_name) loaded_networks.clear() @@ -273,37 +309,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 loaded_networks.append(net) - for emb_name, data in net.bundle_embeddings.items(): - # textual inversion embeddings - if 'string_to_param' in 
data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding - vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} - shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] - vectors = data['clip_g'].shape[0] - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - else: - raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.") - - embedding = Embedding(vec, emb_name) - embedding.vectors = vectors - embedding.shape = shape + for emb_name, embedding in net.bundle_embeddings.items(): + if embedding.loaded is None and emb_name in emb_db.word_embeddings: + logger.warning( + f'Skip bundle embedding: "{emb_name}"' + ' as it was already loaded from embeddings folder' + ) + continue + embedding.loaded = False if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: + embedding.loaded = True emb_db.register_embedding(embedding, shared.sd_model) else: emb_db.skipped_embeddings[name] = embedding From 891ccb767c3815db48a124677d1cd0f204018ad4 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 10 Oct 2023 15:07:25 +0800 Subject: [PATCH 0157/1174] Fix lint --- extensions-builtin/Lora/lora_logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py index d50e90f0940..d51de29704f 100644 --- a/extensions-builtin/Lora/lora_logger.py +++ b/extensions-builtin/Lora/lora_logger.py @@ -30,4 +30,4 @@ def format(self, record): handler.setFormatter( ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s") ) - logger.addHandler(handler) \ No newline at end of file + logger.addHandler(handler) From dbb10fbd8c2dd4f3ca83a1d2e15e188799074ce4 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 1 Oct 2023 20:18:25 +0900 Subject: [PATCH 0158/1174] show the preview image in the modalview if available --- javascript/imageviewer.js | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index c21d396eefd..e4dae91bc9d 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -33,8 +33,11 @@ function updateOnBackgroundChange() { const modalImage = gradioApp().getElementById("modalImage"); if (modalImage && modalImage.offsetParent) { let currentButton = selected_gallery_button(); - - if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { + let preview = gradioApp().querySelectorAll('.livePreview > img'); + if (preview.length > 0) { + // show preview image if available + modalImage.src = preview[preview.length - 1].src; + } else if (currentButton?.children?.length > 0 && modalImage.src != 
currentButton.children[0].src) { modalImage.src = currentButton.children[0].src; if (modalImage.style.display === 'none') { const modal = gradioApp().getElementById("lightboxModal"); From 906d1179e9a333eeb0f12a95b592dd5b44eb0aaa Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 11 Oct 2023 21:26:58 -0700 Subject: [PATCH 0159/1174] support inference with LyCORIS GLora networks --- extensions-builtin/Lora/network_glora.py | 33 ++++++++++++++++++++++++ extensions-builtin/Lora/networks.py | 2 ++ 2 files changed, 35 insertions(+) create mode 100644 extensions-builtin/Lora/network_glora.py diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py new file mode 100644 index 00000000000..492d487078d --- /dev/null +++ b/extensions-builtin/Lora/network_glora.py @@ -0,0 +1,33 @@ + +import network + +class ModuleTypeGLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]): + return NetworkModuleGLora(net, weights) + + return None + +# adapted from https://github.com/KohakuBlueleaf/LyCORIS +class NetworkModuleGLora(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.w1a = weights.w["a1.weight"] + self.w1b = weights.w["b1.weight"] + self.w2a = weights.w["a2.weight"] + self.w2b = weights.w["b2.weight"] + + def calc_updown(self, orig_weight): + w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w1a.size(0), w1b.size(1)] + updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 315682b31b0..ddab3c55e4b 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -5,6 +5,7 @@ import lora_patches import network import network_lora +import network_glora import network_hada import network_ia3 import network_lokr @@ -23,6 +24,7 @@ network_lokr.ModuleTypeLokr(), network_full.ModuleTypeFull(), network_norm.ModuleTypeNorm(), + network_glora.ModuleTypeGLora(), ] From fbc8d213546047d8970b92809e15b33e8a1301be Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 14 Oct 2023 02:39:04 +0900 Subject: [PATCH 0160/1174] fix IndexError: list index out of range when interrupted during postprocessing --- modules/processing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/processing.py b/modules/processing.py index 36bc94f7867..df037fb0ea7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -820,6 +820,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: state.skipped = False if state.interrupted: + infotexts.append(Processed(p, []).infotext(p, 0)) break sd_models.reload_model_weights() # model can be changed for example by refiner From 44d14bc32e4a5501df04f844a00f9e18b777f7eb Mon Sep 17 00:00:00 2001 From: Gleb Alekseev Date: Fri, 13 Oct 2023 15:08:59 -0300 Subject: [PATCH 0161/1174] added an option to play the notification sound or not --- modules/shared_options.py | 1 + 
modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 00b273faa54..afcbf9b858d 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -22,6 +22,7 @@ } options_templates.update(options_section(('saving-images', "Saving images/grids"), { + "notification_audio": OptionInfo(True, "Play notification sound after image generation", comment_after="(notification.mp3 should be present in the root directory)").needs_reload_ui(), "samples_save": OptionInfo(True, "Always save all generated images"), "samples_format": OptionInfo('png', 'File format for images'), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), diff --git a/modules/ui.py b/modules/ui.py index 579bab9800c..df327891547 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1286,7 +1286,7 @@ def get_textual_inversion_template_names(): loadsave.setup_ui() - if os.path.exists(os.path.join(script_path, "notification.mp3")): + if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio: gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) footer = shared.html("footer.html") From 954499a49409582085ed288b94b837ecae7ff86c Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Fri, 13 Oct 2023 16:44:01 -0600 Subject: [PATCH 0162/1174] Convert (emphasis) to (emphasis:1.1) per @SirVeggie's suggestion --- javascript/edit-attention.js | 55 ++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 89b37aaf41a..ee48eb99cb6 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -71,16 +71,23 @@ function keyupEditAttention(event) { closeCharacter = '>'; delta = opts.keyedit_precision_extra; } else if (selectionStart > 0 && /\(.*\)|\[.*\]/s.test(text.slice(selectionStart - 1, selectionEnd + 1))) { - closeCharacter = null; - if (isPlus) { - text = text.slice(0, selectionStart) + text[selectionStart - 1] + text.slice(selectionStart, selectionEnd) + text[selectionEnd] + text.slice(selectionEnd); - selectionStart++; - selectionEnd++; + let start = text[selectionStart - 1]; + let end = text[selectionEnd]; + let numParen = 0; + + while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) { + numParen++; + } + + if (start == "(") { + weight = 1.1 ** numParen; } else { - text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 1); - selectionStart--; - selectionEnd--; + weight = 0.9 ** numParen; } + + text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen); + selectionStart -= numParen - 1; + selectionEnd -= numParen - 1; } else if (selectionStart == 0 || !/\(.*:-?[\d.]+\)/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(")") + 1))) { // do not include spaces at the end while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { @@ -97,23 +104,21 @@ function keyupEditAttention(event) { selectionEnd++; } - if (closeCharacter) { - var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; - var weight = 
parseFloat(text.slice(selectionEnd + 1, selectionEnd + end)); - if (isNaN(weight)) return; - - weight += isPlus ? delta : -delta; - weight = parseFloat(weight.toPrecision(12)); - if (String(weight).length == 1) weight += ".0"; - - if (closeCharacter == ')' && weight == 1) { - var endParenPos = text.substring(selectionEnd).indexOf(')'); - text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1); - selectionStart--; - selectionEnd--; - } else { - text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); - } + var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end)); + if (isNaN(weight)) return; + + weight += isPlus ? delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if (Number.isInteger(weight)) weight += ".0"; + + if (closeCharacter == ')' && weight == 1) { + var endParenPos = text.substring(selectionEnd).indexOf(')'); + text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1); + selectionStart--; + selectionEnd--; + } else { + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); } target.focus(); From fff1a0c74fb769701ade393cda005bf5975ec5f4 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Fri, 13 Oct 2023 17:18:02 -0600 Subject: [PATCH 0163/1174] Make attention conversion optional Fix square brackets multiplier --- javascript/edit-attention.js | 71 ++++++++++++++++++++++-------------- modules/shared_options.py | 2 +- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index ee48eb99cb6..8b6dd2406fc 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -73,21 +73,34 @@ function keyupEditAttention(event) { } else if (selectionStart > 0 && /\(.*\)|\[.*\]/s.test(text.slice(selectionStart - 1, selectionEnd + 1))) { let start = text[selectionStart - 1]; let end = text[selectionEnd]; - let numParen = 0; - - while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) { - numParen++; - } - - if (start == "(") { - weight = 1.1 ** numParen; + if (opts.keyedit_convert) { + let numParen = 0; + + while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) { + numParen++; + } + + if (start == "(") { + weight = 1.1 ** numParen; + } else { + weight = (1 / 1.1) ** numParen; + } + + text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen); + selectionStart -= numParen - 1; + selectionEnd -= numParen - 1; } else { - weight = 0.9 ** numParen; + closeCharacter = null; + if (isPlus) { + text = text.slice(0, selectionStart) + start + text.slice(selectionStart, selectionEnd) + end + text.slice(selectionEnd); + selectionStart++; + selectionEnd++; + } else { + text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 1); + selectionStart--; + selectionEnd--; + } } - - text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen); - selectionStart -= numParen - 1; - selectionEnd -= numParen - 1; } else if (selectionStart == 0 || !/\(.*:-?[\d.]+\)/s.test(text.slice(selectionStart - 1, selectionEnd + 
text.slice(selectionEnd).indexOf(")") + 1))) { // do not include spaces at the end while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { @@ -104,21 +117,23 @@ function keyupEditAttention(event) { selectionEnd++; } - var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; - var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end)); - if (isNaN(weight)) return; - - weight += isPlus ? delta : -delta; - weight = parseFloat(weight.toPrecision(12)); - if (Number.isInteger(weight)) weight += ".0"; - - if (closeCharacter == ')' && weight == 1) { - var endParenPos = text.substring(selectionEnd).indexOf(')'); - text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1); - selectionStart--; - selectionEnd--; - } else { - text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); + if (closeCharacter) { + var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end)); + if (isNaN(weight)) return; + + weight += isPlus ? delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if (Number.isInteger(weight)) weight += ".0"; + + if (closeCharacter == ')' && weight == 1) { + var endParenPos = text.substring(selectionEnd).indexOf(')'); + text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1); + selectionStart--; + selectionEnd--; + } else { + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); + } } target.focus(); diff --git a/modules/shared_options.py b/modules/shared_options.py index 4824323443c..a674d3da2dc 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -258,6 +258,7 @@ "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~()[]<>| ", "Ctrl+up/down word delimiters"), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), + "keyedit_convert": OptionInfo(True, "Convert (attention) to (attention:1.1)"), "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), @@ -332,4 +333,3 @@ "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), })) - From 3a66c3c9e1cecab5095f37964d40a5f4cde317af Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 14 Oct 2023 07:35:06 +0300 Subject: [PATCH 0164/1174] put notification.mp3 option at the end of the page --- modules/shared_options.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index cb3566382b7..ce395302ede 100644 --- 
a/modules/shared_options.py +++ b/modules/shared_options.py @@ -22,7 +22,6 @@ } options_templates.update(options_section(('saving-images', "Saving images/grids"), { - "notification_audio": OptionInfo(True, "Play notification sound after image generation", comment_after="(notification.mp3 should be present in the root directory)").needs_reload_ui(), "samples_save": OptionInfo(True, "Always save all generated images"), "samples_format": OptionInfo('png', 'File format for images'), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), @@ -63,6 +62,8 @@ "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."), + + "notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(), })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { From a109c7aeb8871fe0ae201794f140f8f2e9b5c3ac Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 14 Oct 2023 07:49:03 +0300 Subject: [PATCH 0165/1174] more general case of adding an infotext when no images have been generated --- modules/processing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index df037fb0ea7..816f5fc7131 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -820,7 +820,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: state.skipped = False if state.interrupted: - infotexts.append(Processed(p, []).infotext(p, 0)) break sd_models.reload_model_weights() # model can be changed for example by refiner @@ -961,6 +960,9 @@ def infotext(index=0, use_main_prompt=False): state.nextjob() + if not infotexts: + infotexts.append(Processed(p, []).infotext(p, 0)) + p.color_corrections = None index_of_first_image = 0 From 0619df9835833079f8ba5cb5a510b55c4433acaf Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 14 Oct 2023 08:01:04 +0300 Subject: [PATCH 0166/1174] use shallow copy for #13535 --- modules/sd_models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 2b43868e9fb..c8efeedca0c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,5 +1,4 @@ import collections -import copy import os.path import sys import gc @@ -360,7 +359,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model - checkpoints_loaded[checkpoint_info] = copy.deepcopy(state_dict) + checkpoints_loaded[checkpoint_info] = state_dict.copy() model.load_state_dict(state_dict, strict=False) timer.record("apply weights to model") From a8cbe50c9fa324ed887089e4333452ecc4355c92 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 14 Oct 2023 12:14:56 +0300 Subject: [PATCH 0167/1174] remove duplicated code --- extensions-builtin/Lora/networks.py | 31 +------- .../textual_inversion/textual_inversion.py | 74 ++++++++++--------- 2 files changed, 42 insertions(+), 63 deletions(-) diff --git 
a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 12f70576981..d5f0f9f1650 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -15,7 +15,7 @@ from typing import Union from modules import shared, devices, sd_models, errors, scripts, sd_hijack -from modules.textual_inversion.textual_inversion import Embedding +import modules.textual_inversion.textual_inversion as textual_inversion from lora_logger import logger @@ -210,34 +210,7 @@ def load_network(name, network_on_disk): embeddings = {} for emb_name, data in bundle_embeddings.items(): - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding - vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} - shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] - vectors = data['clip_g'].shape[0] - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - else: - raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.") - - embedding = Embedding(vec, emb_name) - embedding.vectors = vectors - embedding.shape = shape + embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name) embedding.loaded = None embeddings[emb_name] = embedding diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 401a0a2ab02..04dda585cc9 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -181,40 +181,7 @@ def load_from_file(self, path, filename): else: return - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding - vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} - shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] - vectors = data['clip_g'].shape[0] - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - else: 
- raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vectors - embedding.shape = shape - embedding.filename = path - embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') + embedding = create_embedding_from_data(data, name, filename=filename, filepath=path) if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) @@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): return fn +def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None): + if 'string_to_param' in data: # textual inversion embeddings + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vectors + embedding.shape = shape + + if filepath: + embedding.filename = filepath + embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '') + + return embedding + + def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return From 117ec7199416b56a16cbc5591f13fe6cce521d9e Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 14 Oct 2023 21:58:28 +0900 Subject: [PATCH 0168/1174] support webui.settings.bat --- webui.bat | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/webui.bat b/webui.bat index a630ea4d9ff..e2c9079d2fb 100644 --- a/webui.bat +++ b/webui.bat @@ -1,5 +1,9 @@ @echo off +if exist webui.settings.bat ( + call webui.settings.bat +) + if not defined PYTHON (set PYTHON=python) if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%") if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") From f00eaa4d000856964f810aa4b651200861bafc6f Mon Sep 17 00:00:00 2001 From: Khachatur Avanesian Date: Sun, 15 Oct 2023 02:34:03 +0300 Subject: [PATCH 0169/1174] Start / Restart generation by Ctrl (Alt) + Enter Add ability to interrupt current generation and start generation again by Ctrl (Alt) + Enter --- 
script.js | 328 +++++++++++++++++++++++++++--------------------------- 1 file changed, 165 insertions(+), 163 deletions(-) diff --git a/script.js b/script.js index 34cca7651dd..8bf1257da29 100644 --- a/script.js +++ b/script.js @@ -1,163 +1,165 @@ -function gradioApp() { - const elems = document.getElementsByTagName('gradio-app'); - const elem = elems.length == 0 ? document : elems[0]; - - if (elem !== document) { - elem.getElementById = function(id) { - return document.getElementById(id); - }; - } - return elem.shadowRoot ? elem.shadowRoot : elem; -} - -/** - * Get the currently selected top-level UI tab button (e.g. the button that says "Extras"). - */ -function get_uiCurrentTab() { - return gradioApp().querySelector('#tabs > .tab-nav > button.selected'); -} - -/** - * Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI). - */ -function get_uiCurrentTabContent() { - return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])'); -} - -var uiUpdateCallbacks = []; -var uiAfterUpdateCallbacks = []; -var uiLoadedCallbacks = []; -var uiTabChangeCallbacks = []; -var optionsChangedCallbacks = []; -var uiAfterUpdateTimeout = null; -var uiCurrentTab = null; - -/** - * Register callback to be called at each UI update. - * The callback receives an array of MutationRecords as an argument. - */ -function onUiUpdate(callback) { - uiUpdateCallbacks.push(callback); -} - -/** - * Register callback to be called soon after UI updates. - * The callback receives no arguments. - * - * This is preferred over `onUiUpdate` if you don't need - * access to the MutationRecords, as your function will - * not be called quite as often. - */ -function onAfterUiUpdate(callback) { - uiAfterUpdateCallbacks.push(callback); -} - -/** - * Register callback to be called when the UI is loaded. - * The callback receives no arguments. - */ -function onUiLoaded(callback) { - uiLoadedCallbacks.push(callback); -} - -/** - * Register callback to be called when the UI tab is changed. - * The callback receives no arguments. - */ -function onUiTabChange(callback) { - uiTabChangeCallbacks.push(callback); -} - -/** - * Register callback to be called when the options are changed. - * The callback receives no arguments. - * @param callback - */ -function onOptionsChanged(callback) { - optionsChangedCallbacks.push(callback); -} - -function executeCallbacks(queue, arg) { - for (const callback of queue) { - try { - callback(arg); - } catch (e) { - console.error("error running callback", callback, ":", e); - } - } -} - -/** - * Schedule the execution of the callbacks registered with onAfterUiUpdate. - * The callbacks are executed after a short while, unless another call to this function - * is made before that time. IOW, the callbacks are executed only once, even - * when there are multiple mutations observed. 
- */ -function scheduleAfterUiUpdateCallbacks() { - clearTimeout(uiAfterUpdateTimeout); - uiAfterUpdateTimeout = setTimeout(function() { - executeCallbacks(uiAfterUpdateCallbacks); - }, 200); -} - -var executedOnLoaded = false; - -document.addEventListener("DOMContentLoaded", function() { - var mutationObserver = new MutationObserver(function(m) { - if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) { - executedOnLoaded = true; - executeCallbacks(uiLoadedCallbacks); - } - - executeCallbacks(uiUpdateCallbacks, m); - scheduleAfterUiUpdateCallbacks(); - const newTab = get_uiCurrentTab(); - if (newTab && (newTab !== uiCurrentTab)) { - uiCurrentTab = newTab; - executeCallbacks(uiTabChangeCallbacks); - } - }); - mutationObserver.observe(gradioApp(), {childList: true, subtree: true}); -}); - -/** - * Add a ctrl+enter as a shortcut to start a generation - */ -document.addEventListener('keydown', function(e) { - var handled = false; - if (e.key !== undefined) { - if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; - } else if (e.keyCode !== undefined) { - if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; - } - if (handled) { - var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); - if (button) { - button.click(); - } - e.preventDefault(); - } -}); - -/** - * checks that a UI element is not in another hidden element or tab content - */ -function uiElementIsVisible(el) { - if (el === document) { - return true; - } - - const computedStyle = getComputedStyle(el); - const isVisible = computedStyle.display !== 'none'; - - if (!isVisible) return false; - return uiElementIsVisible(el.parentNode); -} - -function uiElementInSight(el) { - const clRect = el.getBoundingClientRect(); - const windowHeight = window.innerHeight; - const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight; - - return isOnScreen; -} +function gradioApp() { + const elems = document.getElementsByTagName('gradio-app'); + const elem = elems.length == 0 ? document : elems[0]; + + if (elem !== document) { + elem.getElementById = function(id) { + return document.getElementById(id); + }; + } + return elem.shadowRoot ? elem.shadowRoot : elem; +} + +/** + * Get the currently selected top-level UI tab button (e.g. the button that says "Extras"). + */ +function get_uiCurrentTab() { + return gradioApp().querySelector('#tabs > .tab-nav > button.selected'); +} + +/** + * Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI). + */ +function get_uiCurrentTabContent() { + return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])'); +} + +var uiUpdateCallbacks = []; +var uiAfterUpdateCallbacks = []; +var uiLoadedCallbacks = []; +var uiTabChangeCallbacks = []; +var optionsChangedCallbacks = []; +var uiAfterUpdateTimeout = null; +var uiCurrentTab = null; + +/** + * Register callback to be called at each UI update. + * The callback receives an array of MutationRecords as an argument. + */ +function onUiUpdate(callback) { + uiUpdateCallbacks.push(callback); +} + +/** + * Register callback to be called soon after UI updates. + * The callback receives no arguments. + * + * This is preferred over `onUiUpdate` if you don't need + * access to the MutationRecords, as your function will + * not be called quite as often. + */ +function onAfterUiUpdate(callback) { + uiAfterUpdateCallbacks.push(callback); +} + +/** + * Register callback to be called when the UI is loaded. 
+ * The callback receives no arguments. + */ +function onUiLoaded(callback) { + uiLoadedCallbacks.push(callback); +} + +/** + * Register callback to be called when the UI tab is changed. + * The callback receives no arguments. + */ +function onUiTabChange(callback) { + uiTabChangeCallbacks.push(callback); +} + +/** + * Register callback to be called when the options are changed. + * The callback receives no arguments. + * @param callback + */ +function onOptionsChanged(callback) { + optionsChangedCallbacks.push(callback); +} + +function executeCallbacks(queue, arg) { + for (const callback of queue) { + try { + callback(arg); + } catch (e) { + console.error("error running callback", callback, ":", e); + } + } +} + +/** + * Schedule the execution of the callbacks registered with onAfterUiUpdate. + * The callbacks are executed after a short while, unless another call to this function + * is made before that time. IOW, the callbacks are executed only once, even + * when there are multiple mutations observed. + */ +function scheduleAfterUiUpdateCallbacks() { + clearTimeout(uiAfterUpdateTimeout); + uiAfterUpdateTimeout = setTimeout(function() { + executeCallbacks(uiAfterUpdateCallbacks); + }, 200); +} + +var executedOnLoaded = false; + +document.addEventListener("DOMContentLoaded", function() { + var mutationObserver = new MutationObserver(function(m) { + if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) { + executedOnLoaded = true; + executeCallbacks(uiLoadedCallbacks); + } + + executeCallbacks(uiUpdateCallbacks, m); + scheduleAfterUiUpdateCallbacks(); + const newTab = get_uiCurrentTab(); + if (newTab && (newTab !== uiCurrentTab)) { + uiCurrentTab = newTab; + executeCallbacks(uiTabChangeCallbacks); + } + }); + mutationObserver.observe(gradioApp(), {childList: true, subtree: true}); +}); + +/** + * Add a Ctrl (Alt) + Enter as a shortcut to start / restart a generation + */ +document.addEventListener('keydown', (e) => { + const isEnter = e.key === 'Enter' || e.keyCode === 13 + const isModifierKey = e.metaKey || e.ctrlKey || e.altKey + + const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]') + const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]') + + if (isEnter && isModifierKey) { + if (interruptButton.style.display === 'block') { + interruptButton.click() + setTimeout(() => generateButton.click(), 500) + } else { + generateButton.click() + } + e.preventDefault() + } +}) + +/** + * checks that a UI element is not in another hidden element or tab content + */ +function uiElementIsVisible(el) { + if (el === document) { + return true; + } + + const computedStyle = getComputedStyle(el); + const isVisible = computedStyle.display !== 'none'; + + if (!isVisible) return false; + return uiElementIsVisible(el.parentNode); +} + +function uiElementInSight(el) { + const clRect = el.getBoundingClientRect(); + const windowHeight = window.innerHeight; + const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight; + + return isOnScreen; +} From 0d65d0eabded4a79c3168ed14599db3970d414c5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 15 Oct 2023 08:45:38 +0300 Subject: [PATCH 0170/1174] add an option to not print stack traces on ctrl+c. 
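Previously a ctrl+c always called dumpstacks() before exiting, printing every thread's stack trace on a routine shutdown. The handler now consults a new dump_stacks_on_signal setting (off by default): a plain interrupt prints just the one-line "Interrupted with signal 2 in <frame ...>" message and exits, and the full thread dump is opt-in.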
--- modules/initialize_util.py | 6 +++++- modules/shared_options.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 2894eee4c1a..2e9b6d895f4 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -150,10 +150,14 @@ def dumpstacks(): def configure_sigint_handler(): # make the program just exit at ctrl+c without waiting for anything + + from modules import shared + def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') - dumpstacks() + if shared.opts.dump_stacks_on_signal: + dumpstacks() os._exit(0) diff --git a/modules/shared_options.py b/modules/shared_options.py index ce395302ede..6ef9fdd58c1 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -112,6 +112,7 @@ "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), + "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), })) options_templates.update(options_section(('API', "API"), { From 282903bb6798f49af66f6935ee4dc0015895cf7c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 15 Oct 2023 09:41:02 +0300 Subject: [PATCH 0171/1174] repair unload sd checkpoint button --- modules/api/api.py | 11 +++++------ modules/sd_models.py | 13 +------------ modules/ui_settings.py | 24 +++++++++++++++++------- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index efedafa42e2..09083874792 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,15 +17,14 @@ from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork -from PIL import PngImagePlugin,Image -from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases +from PIL import PngImagePlugin, Image from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -541,12 +540,12 @@ def interruptapi(self): return {} def unloadapi(self): - unload_model_weights() + sd_models.unload_model_weights() return {} def reloadapi(self): - reload_model_weights() + sd_models.send_model_to_device(shared.sd_model) return {} @@ -566,7 +565,7 @@ def get_config(self): def set_config(self, req: dict[str, Any]): checkpoint_name = req.get("sd_model_checkpoint", None) - if checkpoint_name is not 
None and checkpoint_name not in checkpoint_aliases: + if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): diff --git a/modules/sd_models.py b/modules/sd_models.py index c8efeedca0c..3b6cdea18d4 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,7 +1,6 @@ import collections import os.path import sys -import gc import threading import torch @@ -798,17 +797,7 @@ def reload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None): - timer = Timer() - - if model_data.sd_model: - model_data.sd_model.to(devices.cpu) - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) - model_data.sd_model = None - sd_model = None - gc.collect() - devices.torch_gc() - - print(f"Unloaded weights {timer.summary()}.") + send_model_to_cpu(sd_model or shared.sd_model) return sd_model diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 74a3aef32cc..e054d00ab04 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -1,6 +1,6 @@ import gradio as gr -from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo +from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer from modules.call_queue import wrap_gradio_call from modules.shared import opts from modules.ui_components import FormRow @@ -177,8 +177,8 @@ def create_ui(self, loadsave, dummy_component): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") with gr.Row(): - unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") - reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") + unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model") + reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model") with gr.Row(): calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash") calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1) @@ -194,16 +194,26 @@ def create_ui(self, loadsave, dummy_component): self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + def call_func_and_return_text(func, text): + def handler(): + t = timer.Timer() + func() + t.record(text) + + return f'{text} in {t.total:.1f}s' + + return handler + unload_sd_model.click( - fn=sd_models.unload_model_weights, + fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'), inputs=[], - outputs=[] + outputs=[self.result] ) reload_sd_model.click( - fn=sd_models.reload_model_weights, + fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'), inputs=[], - outputs=[] + outputs=[self.result] ) request_notifications.click( From 2f6ea8b10312e700f8d01f90dff15d17690ce49c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 15 Oct 2023 10:12:38 +0300 Subject: [PATCH 0172/1174] respect keyedit_precision_attention setting when 
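Worked example, assuming the default keyedit_precision_attention of 0.1: ((word)) converts with weight 1.1 ** 2 = 1.21; the new line snaps that to 1.2 before the usual ±0.1 step is applied, so one ctrl+up now yields (word:1.3) instead of (word:1.31).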
converting from old (((attention))) syntax --- javascript/edit-attention.js | 2 ++ 1 file changed, 2 insertions(+) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 8b6dd2406fc..45d9a788a17 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -86,6 +86,8 @@ function keyupEditAttention(event) { weight = (1 / 1.1) ** numParen; } + weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention; + text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen); selectionStart -= numParen - 1; selectionEnd -= numParen - 1; From 77bd953da2f96845b196f242f7ba51138cd54e3b Mon Sep 17 00:00:00 2001 From: Khachatur Avanesian Date: Sun, 15 Oct 2023 10:25:36 +0300 Subject: [PATCH 0173/1174] Update script.js Exclude lambda --- script.js | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/script.js b/script.js index 8bf1257da29..2fab1554bbd 100644 --- a/script.js +++ b/script.js @@ -123,23 +123,25 @@ document.addEventListener("DOMContentLoaded", function() { /** * Add a Ctrl (Alt) + Enter as a shortcut to start / restart a generation */ -document.addEventListener('keydown', (e) => { - const isEnter = e.key === 'Enter' || e.keyCode === 13 - const isModifierKey = e.metaKey || e.ctrlKey || e.altKey +document.addEventListener('keydown', function(e) { + const isEnter = e.key === 'Enter' || e.keyCode === 13; + const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; - const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]') - const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]') + const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); + const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); if (isEnter && isModifierKey) { if (interruptButton.style.display === 'block') { - interruptButton.click() - setTimeout(() => generateButton.click(), 500) + interruptButton.click(); + setTimeout(function() { + generateButton.click(); + }, 500); } else { - generateButton.click() + generateButton.click(); } - e.preventDefault() + e.preventDefault(); } -}) +}); /** * checks that a UI element is not in another hidden element or tab content From d295e97a0d7ca985ab21e935fd933ce629fba2bf Mon Sep 17 00:00:00 2001 From: Khachatur Avanesian Date: Sun, 15 Oct 2023 10:37:48 +0300 Subject: [PATCH 0174/1174] Update script.js LF instead CRLF --- script.js | 130 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 68 insertions(+), 62 deletions(-) diff --git a/script.js b/script.js index 2fab1554bbd..8af9773fdcb 100644 --- a/script.js +++ b/script.js @@ -1,43 +1,45 @@ function gradioApp() { - const elems = document.getElementsByTagName('gradio-app'); - const elem = elems.length == 0 ? document : elems[0]; + const elems = document.getElementsByTagName('gradio-app') + const elem = elems.length == 0 ? document : elems[0] if (elem !== document) { - elem.getElementById = function(id) { - return document.getElementById(id); - }; + elem.getElementById = function (id) { + return document.getElementById(id) + } } - return elem.shadowRoot ? elem.shadowRoot : elem; + return elem.shadowRoot ? elem.shadowRoot : elem } /** * Get the currently selected top-level UI tab button (e.g. the button that says "Extras"). 
*/ function get_uiCurrentTab() { - return gradioApp().querySelector('#tabs > .tab-nav > button.selected'); + return gradioApp().querySelector('#tabs > .tab-nav > button.selected') } /** * Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI). */ function get_uiCurrentTabContent() { - return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])'); + return gradioApp().querySelector( + '#tabs > .tabitem[id^=tab_]:not([style*="display: none"])' + ) } -var uiUpdateCallbacks = []; -var uiAfterUpdateCallbacks = []; -var uiLoadedCallbacks = []; -var uiTabChangeCallbacks = []; -var optionsChangedCallbacks = []; -var uiAfterUpdateTimeout = null; -var uiCurrentTab = null; +var uiUpdateCallbacks = [] +var uiAfterUpdateCallbacks = [] +var uiLoadedCallbacks = [] +var uiTabChangeCallbacks = [] +var optionsChangedCallbacks = [] +var uiAfterUpdateTimeout = null +var uiCurrentTab = null /** * Register callback to be called at each UI update. * The callback receives an array of MutationRecords as an argument. */ function onUiUpdate(callback) { - uiUpdateCallbacks.push(callback); + uiUpdateCallbacks.push(callback) } /** @@ -49,7 +51,7 @@ function onUiUpdate(callback) { * not be called quite as often. */ function onAfterUiUpdate(callback) { - uiAfterUpdateCallbacks.push(callback); + uiAfterUpdateCallbacks.push(callback) } /** @@ -57,7 +59,7 @@ function onAfterUiUpdate(callback) { * The callback receives no arguments. */ function onUiLoaded(callback) { - uiLoadedCallbacks.push(callback); + uiLoadedCallbacks.push(callback) } /** @@ -65,7 +67,7 @@ function onUiLoaded(callback) { * The callback receives no arguments. */ function onUiTabChange(callback) { - uiTabChangeCallbacks.push(callback); + uiTabChangeCallbacks.push(callback) } /** @@ -74,15 +76,15 @@ function onUiTabChange(callback) { * @param callback */ function onOptionsChanged(callback) { - optionsChangedCallbacks.push(callback); + optionsChangedCallbacks.push(callback) } function executeCallbacks(queue, arg) { for (const callback of queue) { try { - callback(arg); + callback(arg) } catch (e) { - console.error("error running callback", callback, ":", e); + console.error('error running callback', callback, ':', e) } } } @@ -94,74 +96,78 @@ function executeCallbacks(queue, arg) { * when there are multiple mutations observed. 
*/ function scheduleAfterUiUpdateCallbacks() { - clearTimeout(uiAfterUpdateTimeout); - uiAfterUpdateTimeout = setTimeout(function() { - executeCallbacks(uiAfterUpdateCallbacks); - }, 200); + clearTimeout(uiAfterUpdateTimeout) + uiAfterUpdateTimeout = setTimeout(function () { + executeCallbacks(uiAfterUpdateCallbacks) + }, 200) } -var executedOnLoaded = false; +var executedOnLoaded = false -document.addEventListener("DOMContentLoaded", function() { - var mutationObserver = new MutationObserver(function(m) { +document.addEventListener('DOMContentLoaded', function () { + var mutationObserver = new MutationObserver(function (m) { if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) { - executedOnLoaded = true; - executeCallbacks(uiLoadedCallbacks); + executedOnLoaded = true + executeCallbacks(uiLoadedCallbacks) } - executeCallbacks(uiUpdateCallbacks, m); - scheduleAfterUiUpdateCallbacks(); - const newTab = get_uiCurrentTab(); - if (newTab && (newTab !== uiCurrentTab)) { - uiCurrentTab = newTab; - executeCallbacks(uiTabChangeCallbacks); + executeCallbacks(uiUpdateCallbacks, m) + scheduleAfterUiUpdateCallbacks() + const newTab = get_uiCurrentTab() + if (newTab && newTab !== uiCurrentTab) { + uiCurrentTab = newTab + executeCallbacks(uiTabChangeCallbacks) } - }); - mutationObserver.observe(gradioApp(), {childList: true, subtree: true}); -}); + }) + mutationObserver.observe(gradioApp(), {childList: true, subtree: true}) +}) /** - * Add a Ctrl (Alt) + Enter as a shortcut to start / restart a generation + * Add a ctrl+enter as a shortcut to start a generation */ -document.addEventListener('keydown', function(e) { - const isEnter = e.key === 'Enter' || e.keyCode === 13; - const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; +document.addEventListener('keydown', function (e) { + const isEnter = e.key === 'Enter' || e.keyCode === 13 + const isModifierKey = e.metaKey || e.ctrlKey || e.altKey - const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); - const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); + const interruptButton = get_uiCurrentTabContent().querySelector( + 'button[id$=_interrupt]' + ) + const generateButton = get_uiCurrentTabContent().querySelector( + 'button[id$=_generate]' + ) if (isEnter && isModifierKey) { if (interruptButton.style.display === 'block') { - interruptButton.click(); - setTimeout(function() { - generateButton.click(); - }, 500); + interruptButton.click() + setTimeout(function () { + generateButton.click() + }, 500) } else { - generateButton.click(); + generateButton.click() } - e.preventDefault(); + e.preventDefault() } -}); +}) /** * checks that a UI element is not in another hidden element or tab content */ function uiElementIsVisible(el) { if (el === document) { - return true; + return true } - const computedStyle = getComputedStyle(el); - const isVisible = computedStyle.display !== 'none'; + const computedStyle = getComputedStyle(el) + const isVisible = computedStyle.display !== 'none' - if (!isVisible) return false; - return uiElementIsVisible(el.parentNode); + if (!isVisible) return false + return uiElementIsVisible(el.parentNode) } function uiElementInSight(el) { - const clRect = el.getBoundingClientRect(); - const windowHeight = window.innerHeight; - const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight; + const clRect = el.getBoundingClientRect() + const windowHeight = window.innerHeight + const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight - 
return isOnScreen; + return isOnScreen } From 3e223523ced2a19347d0b42b662129947152dc49 Mon Sep 17 00:00:00 2001 From: Khachatur Avanesian Date: Sun, 15 Oct 2023 10:48:50 +0300 Subject: [PATCH 0175/1174] Update script.js --- script.js | 128 ++++++++++++++++++++++++++---------------------------- 1 file changed, 61 insertions(+), 67 deletions(-) diff --git a/script.js b/script.js index 8af9773fdcb..0f63ee9335a 100644 --- a/script.js +++ b/script.js @@ -1,45 +1,43 @@ function gradioApp() { - const elems = document.getElementsByTagName('gradio-app') - const elem = elems.length == 0 ? document : elems[0] + const elems = document.getElementsByTagName('gradio-app'); + const elem = elems.length == 0 ? document : elems[0]; if (elem !== document) { - elem.getElementById = function (id) { - return document.getElementById(id) - } + elem.getElementById = function(id) { + return document.getElementById(id); + }; } - return elem.shadowRoot ? elem.shadowRoot : elem + return elem.shadowRoot ? elem.shadowRoot : elem; } /** * Get the currently selected top-level UI tab button (e.g. the button that says "Extras"). */ function get_uiCurrentTab() { - return gradioApp().querySelector('#tabs > .tab-nav > button.selected') + return gradioApp().querySelector('#tabs > .tab-nav > button.selected'); } /** * Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI). */ function get_uiCurrentTabContent() { - return gradioApp().querySelector( - '#tabs > .tabitem[id^=tab_]:not([style*="display: none"])' - ) + return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])'); } -var uiUpdateCallbacks = [] -var uiAfterUpdateCallbacks = [] -var uiLoadedCallbacks = [] -var uiTabChangeCallbacks = [] -var optionsChangedCallbacks = [] -var uiAfterUpdateTimeout = null -var uiCurrentTab = null +var uiUpdateCallbacks = []; +var uiAfterUpdateCallbacks = []; +var uiLoadedCallbacks = []; +var uiTabChangeCallbacks = []; +var optionsChangedCallbacks = []; +var uiAfterUpdateTimeout = null; +var uiCurrentTab = null; /** * Register callback to be called at each UI update. * The callback receives an array of MutationRecords as an argument. */ function onUiUpdate(callback) { - uiUpdateCallbacks.push(callback) + uiUpdateCallbacks.push(callback); } /** @@ -51,7 +49,7 @@ function onUiUpdate(callback) { * not be called quite as often. */ function onAfterUiUpdate(callback) { - uiAfterUpdateCallbacks.push(callback) + uiAfterUpdateCallbacks.push(callback); } /** @@ -59,7 +57,7 @@ function onAfterUiUpdate(callback) { * The callback receives no arguments. */ function onUiLoaded(callback) { - uiLoadedCallbacks.push(callback) + uiLoadedCallbacks.push(callback); } /** @@ -67,7 +65,7 @@ function onUiLoaded(callback) { * The callback receives no arguments. */ function onUiTabChange(callback) { - uiTabChangeCallbacks.push(callback) + uiTabChangeCallbacks.push(callback); } /** @@ -76,15 +74,15 @@ function onUiTabChange(callback) { * @param callback */ function onOptionsChanged(callback) { - optionsChangedCallbacks.push(callback) + optionsChangedCallbacks.push(callback); } function executeCallbacks(queue, arg) { for (const callback of queue) { try { - callback(arg) + callback(arg); } catch (e) { - console.error('error running callback', callback, ':', e) + console.error("error running callback", callback, ":", e); } } } @@ -96,78 +94,74 @@ function executeCallbacks(queue, arg) { * when there are multiple mutations observed. 
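 * Callbacks registered while the timer is pending join the same batch,
 * since the queue is only read at the moment the timer fires.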
*/ function scheduleAfterUiUpdateCallbacks() { - clearTimeout(uiAfterUpdateTimeout) - uiAfterUpdateTimeout = setTimeout(function () { - executeCallbacks(uiAfterUpdateCallbacks) - }, 200) + clearTimeout(uiAfterUpdateTimeout); + uiAfterUpdateTimeout = setTimeout(function() { + executeCallbacks(uiAfterUpdateCallbacks); + }, 200); } -var executedOnLoaded = false +var executedOnLoaded = false; -document.addEventListener('DOMContentLoaded', function () { - var mutationObserver = new MutationObserver(function (m) { +document.addEventListener("DOMContentLoaded", function() { + var mutationObserver = new MutationObserver(function(m) { if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) { - executedOnLoaded = true - executeCallbacks(uiLoadedCallbacks) + executedOnLoaded = true; + executeCallbacks(uiLoadedCallbacks); } - executeCallbacks(uiUpdateCallbacks, m) - scheduleAfterUiUpdateCallbacks() - const newTab = get_uiCurrentTab() - if (newTab && newTab !== uiCurrentTab) { - uiCurrentTab = newTab - executeCallbacks(uiTabChangeCallbacks) + executeCallbacks(uiUpdateCallbacks, m); + scheduleAfterUiUpdateCallbacks(); + const newTab = get_uiCurrentTab(); + if (newTab && (newTab !== uiCurrentTab)) { + uiCurrentTab = newTab; + executeCallbacks(uiTabChangeCallbacks); } - }) - mutationObserver.observe(gradioApp(), {childList: true, subtree: true}) -}) + }); + mutationObserver.observe(gradioApp(), {childList: true, subtree: true}); +}); /** * Add a ctrl+enter as a shortcut to start a generation */ -document.addEventListener('keydown', function (e) { - const isEnter = e.key === 'Enter' || e.keyCode === 13 - const isModifierKey = e.metaKey || e.ctrlKey || e.altKey +document.addEventListener('keydown', function(e) { + const isEnter = e.key === 'Enter' || e.keyCode === 13; + const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; - const interruptButton = get_uiCurrentTabContent().querySelector( - 'button[id$=_interrupt]' - ) - const generateButton = get_uiCurrentTabContent().querySelector( - 'button[id$=_generate]' - ) + const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); + const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); if (isEnter && isModifierKey) { if (interruptButton.style.display === 'block') { - interruptButton.click() - setTimeout(function () { - generateButton.click() - }, 500) + interruptButton.click(); + setTimeout(function() { + generateButton.click(); + }, 500); } else { - generateButton.click() + generateButton.click(); } - e.preventDefault() + e.preventDefault(); } -}) +}); /** * checks that a UI element is not in another hidden element or tab content */ function uiElementIsVisible(el) { if (el === document) { - return true + return true; } - const computedStyle = getComputedStyle(el) - const isVisible = computedStyle.display !== 'none' + const computedStyle = getComputedStyle(el); + const isVisible = computedStyle.display !== 'none'; - if (!isVisible) return false - return uiElementIsVisible(el.parentNode) + if (!isVisible) return false; + return uiElementIsVisible(el.parentNode); } function uiElementInSight(el) { - const clRect = el.getBoundingClientRect() - const windowHeight = window.innerHeight - const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight + const clRect = el.getBoundingClientRect(); + const windowHeight = window.innerHeight; + const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight; - return isOnScreen + return isOnScreen; } From 
d33cb2b8122f259002ce6ef2e7f5cf30dbe069b5 Mon Sep 17 00:00:00 2001 From: Khachatur Avanesian Date: Sun, 15 Oct 2023 11:01:45 +0300 Subject: [PATCH 0176/1174] Add files via upload LF --- script.js | 334 +++++++++++++++++++++++++++--------------------------- 1 file changed, 167 insertions(+), 167 deletions(-) diff --git a/script.js b/script.js index 0f63ee9335a..5f6ee35427b 100644 --- a/script.js +++ b/script.js @@ -1,167 +1,167 @@ -function gradioApp() { - const elems = document.getElementsByTagName('gradio-app'); - const elem = elems.length == 0 ? document : elems[0]; - - if (elem !== document) { - elem.getElementById = function(id) { - return document.getElementById(id); - }; - } - return elem.shadowRoot ? elem.shadowRoot : elem; -} - -/** - * Get the currently selected top-level UI tab button (e.g. the button that says "Extras"). - */ -function get_uiCurrentTab() { - return gradioApp().querySelector('#tabs > .tab-nav > button.selected'); -} - -/** - * Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI). - */ -function get_uiCurrentTabContent() { - return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])'); -} - -var uiUpdateCallbacks = []; -var uiAfterUpdateCallbacks = []; -var uiLoadedCallbacks = []; -var uiTabChangeCallbacks = []; -var optionsChangedCallbacks = []; -var uiAfterUpdateTimeout = null; -var uiCurrentTab = null; - -/** - * Register callback to be called at each UI update. - * The callback receives an array of MutationRecords as an argument. - */ -function onUiUpdate(callback) { - uiUpdateCallbacks.push(callback); -} - -/** - * Register callback to be called soon after UI updates. - * The callback receives no arguments. - * - * This is preferred over `onUiUpdate` if you don't need - * access to the MutationRecords, as your function will - * not be called quite as often. - */ -function onAfterUiUpdate(callback) { - uiAfterUpdateCallbacks.push(callback); -} - -/** - * Register callback to be called when the UI is loaded. - * The callback receives no arguments. - */ -function onUiLoaded(callback) { - uiLoadedCallbacks.push(callback); -} - -/** - * Register callback to be called when the UI tab is changed. - * The callback receives no arguments. - */ -function onUiTabChange(callback) { - uiTabChangeCallbacks.push(callback); -} - -/** - * Register callback to be called when the options are changed. - * The callback receives no arguments. - * @param callback - */ -function onOptionsChanged(callback) { - optionsChangedCallbacks.push(callback); -} - -function executeCallbacks(queue, arg) { - for (const callback of queue) { - try { - callback(arg); - } catch (e) { - console.error("error running callback", callback, ":", e); - } - } -} - -/** - * Schedule the execution of the callbacks registered with onAfterUiUpdate. - * The callbacks are executed after a short while, unless another call to this function - * is made before that time. IOW, the callbacks are executed only once, even - * when there are multiple mutations observed. 
- */ -function scheduleAfterUiUpdateCallbacks() { - clearTimeout(uiAfterUpdateTimeout); - uiAfterUpdateTimeout = setTimeout(function() { - executeCallbacks(uiAfterUpdateCallbacks); - }, 200); -} - -var executedOnLoaded = false; - -document.addEventListener("DOMContentLoaded", function() { - var mutationObserver = new MutationObserver(function(m) { - if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) { - executedOnLoaded = true; - executeCallbacks(uiLoadedCallbacks); - } - - executeCallbacks(uiUpdateCallbacks, m); - scheduleAfterUiUpdateCallbacks(); - const newTab = get_uiCurrentTab(); - if (newTab && (newTab !== uiCurrentTab)) { - uiCurrentTab = newTab; - executeCallbacks(uiTabChangeCallbacks); - } - }); - mutationObserver.observe(gradioApp(), {childList: true, subtree: true}); -}); - -/** - * Add a ctrl+enter as a shortcut to start a generation - */ -document.addEventListener('keydown', function(e) { - const isEnter = e.key === 'Enter' || e.keyCode === 13; - const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; - - const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); - const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); - - if (isEnter && isModifierKey) { - if (interruptButton.style.display === 'block') { - interruptButton.click(); - setTimeout(function() { - generateButton.click(); - }, 500); - } else { - generateButton.click(); - } - e.preventDefault(); - } -}); - -/** - * checks that a UI element is not in another hidden element or tab content - */ -function uiElementIsVisible(el) { - if (el === document) { - return true; - } - - const computedStyle = getComputedStyle(el); - const isVisible = computedStyle.display !== 'none'; - - if (!isVisible) return false; - return uiElementIsVisible(el.parentNode); -} - -function uiElementInSight(el) { - const clRect = el.getBoundingClientRect(); - const windowHeight = window.innerHeight; - const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight; - - return isOnScreen; -} +function gradioApp() { + const elems = document.getElementsByTagName('gradio-app'); + const elem = elems.length == 0 ? document : elems[0]; + + if (elem !== document) { + elem.getElementById = function(id) { + return document.getElementById(id); + }; + } + return elem.shadowRoot ? elem.shadowRoot : elem; +} + +/** + * Get the currently selected top-level UI tab button (e.g. the button that says "Extras"). + */ +function get_uiCurrentTab() { + return gradioApp().querySelector('#tabs > .tab-nav > button.selected'); +} + +/** + * Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI). + */ +function get_uiCurrentTabContent() { + return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])'); +} + +var uiUpdateCallbacks = []; +var uiAfterUpdateCallbacks = []; +var uiLoadedCallbacks = []; +var uiTabChangeCallbacks = []; +var optionsChangedCallbacks = []; +var uiAfterUpdateTimeout = null; +var uiCurrentTab = null; + +/** + * Register callback to be called at each UI update. + * The callback receives an array of MutationRecords as an argument. + */ +function onUiUpdate(callback) { + uiUpdateCallbacks.push(callback); +} + +/** + * Register callback to be called soon after UI updates. + * The callback receives no arguments. + * + * This is preferred over `onUiUpdate` if you don't need + * access to the MutationRecords, as your function will + * not be called quite as often. 
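+ *
+ * A usage sketch (the handler name here is hypothetical):
+ *
+ *     onAfterUiUpdate(function() { refreshMyWidgets(); });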
+ */ +function onAfterUiUpdate(callback) { + uiAfterUpdateCallbacks.push(callback); +} + +/** + * Register callback to be called when the UI is loaded. + * The callback receives no arguments. + */ +function onUiLoaded(callback) { + uiLoadedCallbacks.push(callback); +} + +/** + * Register callback to be called when the UI tab is changed. + * The callback receives no arguments. + */ +function onUiTabChange(callback) { + uiTabChangeCallbacks.push(callback); +} + +/** + * Register callback to be called when the options are changed. + * The callback receives no arguments. + * @param callback + */ +function onOptionsChanged(callback) { + optionsChangedCallbacks.push(callback); +} + +function executeCallbacks(queue, arg) { + for (const callback of queue) { + try { + callback(arg); + } catch (e) { + console.error("error running callback", callback, ":", e); + } + } +} + +/** + * Schedule the execution of the callbacks registered with onAfterUiUpdate. + * The callbacks are executed after a short while, unless another call to this function + * is made before that time. IOW, the callbacks are executed only once, even + * when there are multiple mutations observed. + */ +function scheduleAfterUiUpdateCallbacks() { + clearTimeout(uiAfterUpdateTimeout); + uiAfterUpdateTimeout = setTimeout(function() { + executeCallbacks(uiAfterUpdateCallbacks); + }, 200); +} + +var executedOnLoaded = false; + +document.addEventListener("DOMContentLoaded", function() { + var mutationObserver = new MutationObserver(function(m) { + if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) { + executedOnLoaded = true; + executeCallbacks(uiLoadedCallbacks); + } + + executeCallbacks(uiUpdateCallbacks, m); + scheduleAfterUiUpdateCallbacks(); + const newTab = get_uiCurrentTab(); + if (newTab && (newTab !== uiCurrentTab)) { + uiCurrentTab = newTab; + executeCallbacks(uiTabChangeCallbacks); + } + }); + mutationObserver.observe(gradioApp(), {childList: true, subtree: true}); +}); + +/** + * Add a ctrl+enter as a shortcut to start a generation + */ +document.addEventListener('keydown', function(e) { + const isEnter = e.key === 'Enter' || e.keyCode === 13; + const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; + + const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); + const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); + + if (isEnter && isModifierKey) { + if (interruptButton.style.display === 'block') { + interruptButton.click(); + setTimeout(function() { + generateButton.click(); + }, 500); + } else { + generateButton.click(); + } + e.preventDefault(); + } +}); + +/** + * checks that a UI element is not in another hidden element or tab content + */ +function uiElementIsVisible(el) { + if (el === document) { + return true; + } + + const computedStyle = getComputedStyle(el); + const isVisible = computedStyle.display !== 'none'; + + if (!isVisible) return false; + return uiElementIsVisible(el.parentNode); +} + +function uiElementInSight(el) { + const clRect = el.getBoundingClientRect(); + const windowHeight = window.innerHeight; + const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight; + + return isOnScreen; +} From 8aa13d5dce2789a7d0bd802e6d62453b3c380496 Mon Sep 17 00:00:00 2001 From: Anthony Fu Date: Mon, 16 Oct 2023 14:12:18 +0800 Subject: [PATCH 0177/1174] Interrupt after current generation --- modules/call_queue.py | 1 + modules/img2img.py | 2 +- modules/processing.py | 2 +- modules/shared_options.py | 1 + 
modules/shared_state.py | 11 +++++++++-- scripts/loopback.py | 6 +++--- scripts/xyz_grid.py | 2 +- 7 files changed, 17 insertions(+), 8 deletions(-) diff --git a/modules/call_queue.py b/modules/call_queue.py index ddf0d57383c..01c6d17f685 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -78,6 +78,7 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): shared.state.skipped = False shared.state.interrupted = False + shared.state.interrupted_next = False shared.state.job_count = 0 if not add_stats: diff --git a/modules/img2img.py b/modules/img2img.py index 52cb577a646..31f8c2aaf4c 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -49,7 +49,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if state.skipped: state.skipped = False - if state.interrupted: + if state.interrupted or state.interrupted_next: break try: diff --git a/modules/processing.py b/modules/processing.py index 40598f5cff4..e7eecd66d10 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -819,7 +819,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.skipped: state.skipped = False - if state.interrupted: + if state.interrupted or state.interrupted_next: break sd_models.reload_model_weights() # model can be changed for example by refiner diff --git a/modules/shared_options.py b/modules/shared_options.py index 32bf735323f..4638ef068f1 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -113,6 +113,7 @@ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), + "interrupt_after_current": OptionInfo(False, "Interrupt generation after current image is finished on batch processing"), })) options_templates.update(options_section(('API', "API"), { diff --git a/modules/shared_state.py b/modules/shared_state.py index a68789cc815..c72c3f63cb2 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -12,6 +12,7 @@ class State: skipped = False interrupted = False + interrupted_next = False job = "" job_no = 0 job_count = 0 @@ -76,8 +77,12 @@ def skip(self): log.info("Received skip request") def interrupt(self): - self.interrupted = True - log.info("Received interrupt request") + if shared.opts.interrupt_after_current and self.job_count > 1: + self.interrupted_next = True + log.info("Received interrupt request, interrupt after current job") + else: + self.interrupted = True + log.info("Received interrupt request") def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: @@ -91,6 +96,7 @@ def dict(self): obj = { "skipped": self.skipped, "interrupted": self.interrupted, + "interrupted_next": self.interrupted_next, "job": self.job, "job_count": self.job_count, "job_timestamp": self.job_timestamp, @@ -114,6 +120,7 @@ def begin(self, job: str = "(unknown)"): self.id_live_preview = 0 self.skipped = False self.interrupted = False + self.interrupted_next = False self.textinfo = None self.job = job devices.torch_gc() diff --git a/scripts/loopback.py b/scripts/loopback.py index 2d5feaf9b26..ad921269afb 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -95,7 +95,7 @@ def calculate_denoising_strength(loop): processed = 
processing.process_images(p) # Generation cancelled. - if state.interrupted: + if state.interrupted or state.interrupted_next: break if initial_seed is None: @@ -122,8 +122,8 @@ def calculate_denoising_strength(loop): p.inpainting_fill = original_inpainting_fill - if state.interrupted: - break + if state.interrupted or state.interrupted_next: + break if len(history) > 1: grid = images.image_grid(history, rows=1) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc43d..495008ada08 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -688,7 +688,7 @@ def fix_axis_seeds(axis_opt, axis_list): grid_infotext = [None] * (1 + len(zs)) def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted: + if shared.state.interrupted or state.interrupted_next: return Processed(p, [], p.seed, "") pc = copy(p) From 3d15e58b0a30f2ef1e731f9e429f4d3cf1c259c5 Mon Sep 17 00:00:00 2001 From: Anthony Fu Date: Mon, 16 Oct 2023 15:00:17 +0800 Subject: [PATCH 0178/1174] feat: refactor --- modules/shared_state.py | 12 ++++++------ modules/ui.py | 8 +++++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/modules/shared_state.py b/modules/shared_state.py index c72c3f63cb2..532fdcd8d12 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -77,12 +77,12 @@ def skip(self): log.info("Received skip request") def interrupt(self): - if shared.opts.interrupt_after_current and self.job_count > 1: - self.interrupted_next = True - log.info("Received interrupt request, interrupt after current job") - else: - self.interrupted = True - log.info("Received interrupt request") + self.interrupted = True + log.info("Received interrupt request") + + def interrupt_next(self): + self.interrupted_next = True + log.info("Received interrupt request, interrupt after current job") def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: diff --git a/modules/ui.py b/modules/ui.py index bcf39199715..c30093d755d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -216,8 +216,14 @@ def __init__(self, is_img2img): outputs=[], ) + def interrupt_fn(): + if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + shared.state.interrupt_next() + else: + shared.state.interrupt() + self.interrupt.click( - fn=lambda: shared.state.interrupt(), + fn=interrupt_fn, inputs=[], outputs=[], ) From ec718f76b58b183859ed732e11ec748c41a13f76 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 17 Oct 2023 23:35:50 -0700 Subject: [PATCH 0179/1174] wip incorrect OFT implementation --- extensions-builtin/Lora/network_oft.py | 82 ++++++++++++++++++++++++++ extensions-builtin/Lora/networks.py | 5 ++ 2 files changed, 87 insertions(+) create mode 100644 extensions-builtin/Lora/network_oft.py diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py new file mode 100644 index 00000000000..9ddb175ce06 --- /dev/null +++ b/extensions-builtin/Lora/network_oft.py @@ -0,0 +1,82 @@ +import torch +import network + + +class ModuleTypeOFT(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["oft_blocks"]): + return NetworkModuleOFT(net, weights) + + return None + +# adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +class NetworkModuleOFT(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.oft_blocks = 
weights.w["oft_blocks"] + self.alpha = weights.w["alpha"] + + self.dim = self.oft_blocks.shape[0] + self.num_blocks = self.dim + + #if type(self.alpha) == torch.Tensor: + # self.alpha = self.alpha.detach().numpy() + + if "Linear" in self.sd_module.__class__.__name__: + self.out_dim = self.sd_module.out_features + elif "Conv" in self.sd_module.__class__.__name__: + self.out_dim = self.sd_module.out_channels + + self.constraint = self.alpha * self.out_dim + self.block_size = self.out_dim // self.num_blocks + + self.oft_multiplier = self.multiplier() + + # replace forward method of original linear rather than replacing the module + # self.org_forward = self.sd_module.forward + # self.sd_module.forward = self.forward + + def get_weight(self): + block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + R = torch.block_diag(*block_R_weighted) + + return R + + def calc_updown(self, orig_weight): + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + R = torch.block_diag(*block_R_weighted) + #R = self.get_weight().to(orig_weight.device, dtype=orig_weight.dtype) + # W = R*W_0 + updown = orig_weight + R + output_shape = [R.size(0), orig_weight.size(1)] + return self.finalize_updown(updown, orig_weight, output_shape) + + # def forward(self, x, y=None): + # x = self.org_forward(x) + # if self.oft_multiplier == 0.0: + # return x + + # R = self.get_weight().to(x.device, dtype=x.dtype) + # if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + # else: + # x = torch.matmul(x, R) + # return x diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4cf5..bd1f1b756e0 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -11,6 +11,7 @@ import network_lokr import network_full import network_norm +import network_oft import torch from typing import Union @@ -28,6 +29,7 @@ network_full.ModuleTypeFull(), network_norm.ModuleTypeNorm(), network_glora.ModuleTypeGLora(), + network_oft.ModuleTypeOFT(), ] @@ -183,6 +185,9 @@ def load_network(name, network_on_disk): elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "oft_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) # some SD1 Loras also have correct compvis keys if sd_module is 
None: From 1c6efdbba774d603c592debaccd6f5ad827bd1b2 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:16:01 -0700 Subject: [PATCH 0180/1174] inference working but SLOW --- extensions-builtin/Lora/network_oft.py | 73 +++++++++++++------------- extensions-builtin/Lora/networks.py | 42 +++++++++++++-- 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 9ddb175ce06..f085eca53de 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -12,6 +12,7 @@ def create_module(self, net: network.Network, weights: network.NetworkWeights): # adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) self.oft_blocks = weights.w["oft_blocks"] @@ -20,24 +21,29 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.dim = self.oft_blocks.shape[0] self.num_blocks = self.dim - #if type(self.alpha) == torch.Tensor: - # self.alpha = self.alpha.detach().numpy() - if "Linear" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_features elif "Conv" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_channels - self.constraint = self.alpha * self.out_dim + self.constraint = self.alpha + #self.constraint = self.alpha * self.out_dim self.block_size = self.out_dim // self.num_blocks - self.oft_multiplier = self.multiplier() + self.org_module: list[torch.Module] = [self.sd_module] + + self.R = self.get_weight() - # replace forward method of original linear rather than replacing the module - # self.org_forward = self.sd_module.forward - # self.sd_module.forward = self.forward + self.apply_to() + + # replace forward method of original linear rather than replacing the module + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward - def get_weight(self): + def get_weight(self, multiplier=None): + if not multiplier: + multiplier = self.multiplier() block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) @@ -45,38 +51,31 @@ def get_weight(self): I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + block_R_weighted = multiplier * block_R + (1 - multiplier) * I R = torch.block_diag(*block_R_weighted) return R def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - - block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I - R = torch.block_diag(*block_R_weighted) - #R = self.get_weight().to(orig_weight.device, dtype=orig_weight.dtype) - # W = R*W_0 - updown = orig_weight + R - output_shape = [R.size(0), 
orig_weight.size(1)] + R = self.R + if orig_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + else: + weight = torch.einsum("oi, op -> pi", orig_weight, R) + updown = orig_weight @ R + output_shape = [orig_weight.size(0), R.size(1)] + #output_shape = [R.size(0), orig_weight.size(1)] return self.finalize_updown(updown, orig_weight, output_shape) - # def forward(self, x, y=None): - # x = self.org_forward(x) - # if self.oft_multiplier == 0.0: - # return x - - # R = self.get_weight().to(x.device, dtype=x.dtype) - # if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - # else: - # x = torch.matmul(x, R) - # return x + def forward(self, x, y=None): + x = self.org_forward(x) + if self.multiplier() == 0.0: + return x + R = self.get_weight().to(x.device, dtype=x.dtype) + if x.dim() == 4: + x = x.permute(0, 2, 3, 1) + x = torch.matmul(x, R) + x = x.permute(0, 3, 1, 2) + else: + x = torch.matmul(x, R) + return x diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index bd1f1b756e0..e5e73450bb6 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -169,6 +169,10 @@ def load_network(name, network_on_disk): else: emb_dict[vec_name] = weight bundle_embeddings[emb_name] = emb_dict + + #if key_network_without_network_parts == "oft_unet": + # print(key_network_without_network_parts) + # pass key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -185,15 +189,39 @@ def load_network(name, network_on_disk): elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) - elif sd_module is None and "oft_unet" in key_network_without_network_parts: - key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") - sd_module = shared.sd_model.network_layer_mapping.get(key, None) # some SD1 Loras also have correct compvis keys if sd_module is None: key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "oft_unet" in key_network_without_network_parts: + # UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"] + # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"] + # TODO: Change matched modules based on whether all linear, conv, etc + + key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + #key_no_suffix = key.rsplit("_to_", 1)[0] + ## Match all modules of class CrossAttention + #replace_module_list = [] + #for module_type in UNET_TARGET_REPLACE_MODULE_ATTN_ONLY: + # replace_module_list += [module for k, module in shared.sd_model.network_layer_mapping.items() if module_type in module.__class__.__name__] + + #matched_module = replace_module_list.get(key_no_suffix, None) + #if key.endswith('to_q'): + # sd_module = matched_module.to_q or None + #if key.endswith('to_k'): + # sd_module = matched_module.to_k or None + #if key.endswith('to_v'): + # sd_module = matched_module.to_v or None + #if key.endswith('to_out_0'): + # sd_module =
matched_module.to_out[0] or None + #if key.endswith('to_out_1'): + # sd_module = matched_module.to_out[1] or None + + if sd_module is None: keys_failed_to_match[key_network] = key continue @@ -214,6 +242,14 @@ def load_network(name, network_on_disk): raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") net.modules[key] = net_module + + # replaces forward method of original Linear + # applied_to_count = 0 + #for key, created_module in net.modules.items(): + # if isinstance(created_module, network_oft.NetworkModuleOFT): + # net_module.apply_to() + #applied_to_count += 1 + # print(f'Applied OFT modules: {applied_to_count}') embeddings = {} for emb_name, data in bundle_embeddings.items(): From 853e21d98eada4db9a9fd1ae8eda90cf763e2818 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:27:44 -0700 Subject: [PATCH 0181/1174] faster by using cached R in forward --- extensions-builtin/Lora/network_oft.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index f085eca53de..68efb1db984 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -57,21 +57,32 @@ def get_weight(self, multiplier=None): return R def calc_updown(self, orig_weight): + # this works R = self.R + + # this causes major deepfrying i.e. just doesn't work + # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + if orig_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", orig_weight, R) else: weight = torch.einsum("oi, op -> pi", orig_weight, R) + updown = orig_weight @ R - output_shape = [orig_weight.size(0), R.size(1)] - #output_shape = [R.size(0), orig_weight.size(1)] + output_shape = self.oft_blocks.shape + + ## this works + # updown = orig_weight @ R + # output_shape = [orig_weight.size(0), R.size(1)] + return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: return x - R = self.get_weight().to(x.device, dtype=x.dtype) + #R = self.get_weight().to(x.device, dtype=x.dtype) + R = self.R.to(x.device, dtype=x.dtype) if x.dim() == 4: x = x.permute(0, 2, 3, 1) x = torch.matmul(x, R) From eb01d7f0e0fb46285985803296a25715165fb3f9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:56:53 -0700 Subject: [PATCH 0182/1174] faster by calculating R in updown and using cached R in forward --- extensions-builtin/Lora/network_oft.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 68efb1db984..fd5b0c0fde2 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -58,17 +58,18 @@ def get_weight(self, multiplier=None): def calc_updown(self, orig_weight): # this works - R = self.R + # R = self.R + self.R = self.get_weight(self.multiplier()) - # this causes major deepfrying i.e. just doesn't work + # sending R to device causes major deepfrying i.e. 
just doesn't work # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - if orig_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - else: - weight = torch.einsum("oi, op -> pi", orig_weight, R) + # if orig_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + # else: + # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ R + updown = orig_weight @ self.R output_shape = self.oft_blocks.shape ## this works From 7c128bbdac0da1767c239174e91af6f327845372 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:56:17 +0800 Subject: [PATCH 0183/1174] Add fp8 for sd unet --- extensions-builtin/Lora/network.py | 2 +- extensions-builtin/Lora/network_full.py | 4 ++-- extensions-builtin/Lora/network_glora.py | 10 +++++----- extensions-builtin/Lora/network_hada.py | 12 ++++++------ extensions-builtin/Lora/network_ia3.py | 2 +- extensions-builtin/Lora/network_lokr.py | 18 +++++++++--------- extensions-builtin/Lora/network_lora.py | 6 +++--- extensions-builtin/Lora/network_norm.py | 4 ++-- extensions-builtin/Lora/networks.py | 6 +++--- modules/cmd_args.py | 1 + modules/sd_models.py | 3 +++ 11 files changed, 36 insertions(+), 32 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 6021fd8de0f..a62e5eff9e1 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -137,7 +137,7 @@ def calc_scale(self): def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown += self.bias.to(orig_weight.device, dtype=updown.dtype) updown = updown.reshape(output_shape) if len(output_shape) == 4: diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index bf6930e96c0..f221c95f3b5 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -18,9 +18,9 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): def calc_updown(self, orig_weight): output_shape = self.weight.shape - updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.weight.to(orig_weight.device) if self.ex_bias is not None: - ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.ex_bias.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py index 492d487078d..efe5c6814fa 100644 --- a/extensions-builtin/Lora/network_glora.py +++ b/extensions-builtin/Lora/network_glora.py @@ -22,12 +22,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.w2b = weights.w["b2.weight"] def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] - updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ 
w1a)) return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 5fcb0695fbb..d95a0fd18e3 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,16 +27,16 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.t2 = weights.w.get("hada_t2") def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] if self.t1 is not None: output_shape = [w1a.size(1), w1b.size(1)] - t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + t1 = self.t1.to(orig_weight.device) updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) output_shape += t1.shape[2:] else: @@ -45,7 +45,7 @@ def calc_updown(self, orig_weight): updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) if self.t2 is not None: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) else: updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 7edc4249791..96faeaf3ede 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -17,7 +17,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.on_input = weights.w["on_input"].item() def calc_updown(self, orig_weight): - w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) + w = self.w.to(orig_weight.device) output_shape = [w.size(0), orig_weight.size(1)] if self.on_input: diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 340acdab3d0..fcdaeafd896 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -37,22 +37,22 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): def calc_updown(self, orig_weight): if self.w1 is not None: - w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = self.w1.to(orig_weight.device) else: - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) w1 = w1a @ w1b if self.w2 is not None: - w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = self.w2.to(orig_weight.device) elif self.t2 is None: - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = w2a @ w2b else: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) 
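        # a sketch of the shapes: w2 is rebuilt above from its CP factors
        # (t2, w2a, w2b); the final LoKr update is effectively kron(w1, w2),
        # which is why output_shape multiplies the two factors' dimensions.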
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 26c0a72c237..4cc4029510f 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -61,13 +61,13 @@ def create_module(self, weights, key, none_ok=False): return module def calc_updown(self, orig_weight): - up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + up = self.up_model.weight.to(orig_weight.device) + down = self.down_model.weight.to(orig_weight.device) output_shape = [up.size(0), down.size(1)] if self.mid_model is not None: # cp-decomposition - mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + mid = self.mid_model.weight.to(orig_weight.device) updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) output_shape += mid.shape[2:] else: diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py index ce450158068..d25afcbb928 100644 --- a/extensions-builtin/Lora/network_norm.py +++ b/extensions-builtin/Lora/network_norm.py @@ -18,10 +18,10 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): def calc_updown(self, orig_weight): output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.w_norm.to(orig_weight.device) if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.b_norm.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4cf5..8ea4ea6092d 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -381,12 +381,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn # inpainting model. 
zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight += updown + self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: - self.bias = torch.nn.Parameter(ex_bias) + self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias += ex_bias + self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 4e602a84270..0f14c71e4a8 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,3 +118,4 @@ parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) +parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b6cdea18d4..3b8ff8209b3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -391,6 +391,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + if shared.cmd_opts.opt_unet_fp8_storage: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From 5f9ddfa46f28ca2aa9e0bd832f6bbd67069be63e Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:57:22 +0800 Subject: [PATCH 0184/1174] Add sdxl only arg --- modules/cmd_args.py | 1 + modules/sd_models.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 0f14c71e4a8..20bfb2c442b 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -119,3 +119,4 @@ parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) +parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram (sdxl only)", default=False) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b8ff8209b3..08af128fc4e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -394,6 +394,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.opt_unet_fp8_storage: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + elif
model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From 321680ccd0e0404223fbdf4f26498f7d0317fb75 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 12:41:17 -0700 Subject: [PATCH 0185/1174] refactor: fix constraint, re-use get_weight --- extensions-builtin/Lora/network_oft.py | 40 +++++++++++--------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fd5b0c0fde2..2af1bc4cf32 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -9,7 +9,7 @@ def create_module(self, net: network.Network, weights: network.NetworkWeights): return None -# adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): @@ -17,7 +17,6 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.oft_blocks = weights.w["oft_blocks"] self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] self.num_blocks = self.dim @@ -26,64 +25,57 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): elif "Conv" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_channels - self.constraint = self.alpha - #self.constraint = self.alpha * self.out_dim + self.constraint = self.alpha * self.out_dim self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] - - self.R = self.get_weight() - + self.R = self.get_weight(self.oft_blocks) self.apply_to() # replace forward method of original linear rather than replacing the module + # how do we revert this to unload the weights? def apply_to(self): self.org_forward = self.org_module[0].forward self.org_module[0].forward = self.forward - def get_weight(self, multiplier=None): - if not multiplier: - multiplier = self.multiplier() - block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + def get_weight(self, oft_blocks, multiplier=None): + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - - block_R_weighted = multiplier * block_R + (1 - multiplier) * I - R = torch.block_diag(*block_R_weighted) + #block_R_weighted = multiplier * block_R + (1 - multiplier) * I + #R = torch.block_diag(*block_R_weighted) + R = torch.block_diag(*block_R) return R def calc_updown(self, orig_weight): - # this works - # R = self.R - self.R = self.get_weight(self.multiplier()) + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - # sending R to device causes major deepfrying i.e. 
just doesn't work - # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + R = self.get_weight(oft_blocks) + self.R = R # if orig_weight.dim() == 4: # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) # else: # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ self.R + updown = orig_weight @ R output_shape = self.oft_blocks.shape - ## this works - # updown = orig_weight @ R - # output_shape = [orig_weight.size(0), R.size(1)] - return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: return x + + # calculating R here is excruciatingly slow #R = self.get_weight().to(x.device, dtype=x.dtype) R = self.R.to(x.device, dtype=x.dtype) + if x.dim() == 4: x = x.permute(0, 2, 3, 1) x = torch.matmul(x, R) From d10c4db57ed08234a7aed5f530f269ff78544ab0 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 12:52:14 -0700 Subject: [PATCH 0186/1174] style: formatting --- extensions-builtin/Lora/network_oft.py | 4 +-- extensions-builtin/Lora/networks.py | 35 -------------------------- 2 files changed, 2 insertions(+), 37 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 2af1bc4cf32..0a87958e20b 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -37,7 +37,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): def apply_to(self): self.org_forward = self.org_module[0].forward self.org_module[0].forward = self.forward - + def get_weight(self, oft_blocks, multiplier=None): block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) @@ -66,7 +66,7 @@ def calc_updown(self, orig_weight): output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - + def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index e5e73450bb6..78a97033d58 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -169,10 +169,6 @@ def load_network(name, network_on_disk): else: emb_dict[vec_name] = weight bundle_embeddings[emb_name] = emb_dict - - #if key_network_without_network_parts == "oft_unet": - # print(key_network_without_network_parts) - # pass key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -196,31 +192,8 @@ def load_network(name, network_on_disk): sd_module = shared.sd_model.network_layer_mapping.get(key, None) elif sd_module is None and "oft_unet" in key_network_without_network_parts: - # UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"] - # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"] - # TODO: Change matchedm odules based on whether all linear, conv, etc - key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) - #key_no_suffix = key.rsplit("_to_", 1)[0] - ## Match all modules of class CrossAttention - #replace_module_list = [] - #for module_type in UNET_TARGET_REPLACE_MODULE_ATTN_ONLY: - # replace_module_list += [module for k, module in shared.sd_model.network_layer_mapping.items() if 
module_type in module.__class__.__name__] - - #matched_module = replace_module_list.get(key_no_suffix, None) - #if key.endswith('to_q'): - # sd_module = matched_module.to_q or None - #if key.endswith('to_k'): - # sd_module = matched_module.to_k or None - #if key.endswith('to_v'): - # sd_module = matched_module.to_v or None - #if key.endswith('to_out_0'): - # sd_module = matched_module.to_out[0] or None - #if key.endswith('to_out_1'): - # sd_module = matched_module.to_out[1] or None - if sd_module is None: keys_failed_to_match[key_network] = key @@ -242,14 +215,6 @@ def load_network(name, network_on_disk): raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") net.modules[key] = net_module - - # replaces forward method of original Linear - # applied_to_count = 0 - #for key, created_module in net.modules.items(): - # if isinstance(created_module, network_oft.NetworkModuleOFT): - # net_module.apply_to() - #applied_to_count += 1 - # print(f'Applied OFT modules: {applied_to_count}') embeddings = {} for emb_name, data in bundle_embeddings.items(): From 0550659ce6e1c37d1ab05cb8a2cb31d499fa552f Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:13:02 -0700 Subject: [PATCH 0187/1174] style: fix ambiguous variable name --- extensions-builtin/Lora/network_oft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 0a87958e20b..4e8382c18c7 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -43,8 +43,8 @@ def get_weight(self, oft_blocks, multiplier=None): norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + m_I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) #block_R_weighted = multiplier * block_R + (1 - multiplier) * I #R = torch.block_diag(*block_R_weighted) R = torch.block_diag(*block_R) From 384fab9627942dc7a5771368180bab9cfe0c2877 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 21 Oct 2023 08:45:51 +0300 Subject: [PATCH 0188/1174] rework some of changes for emphasis editing keys, force conversion of old-style emphasis --- javascript/edit-attention.js | 101 ++++++++++++++++------------------- modules/shared_options.py | 3 +- 2 files changed, 47 insertions(+), 57 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 45d9a788a17..f3af9a4c345 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -26,10 +26,15 @@ function keyupEditAttention(event) { // Set the selection to the text between the parenthesis const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); - if (!/.*:-?[\d.]+/s.test(parenContent)) return false; - const lastColon = parenContent.lastIndexOf(":"); - selectionStart = beforeParen + 1; - selectionEnd = selectionStart + lastColon; + if (/.*:-?[\d.]+/s.test(parenContent)) { + const lastColon = parenContent.lastIndexOf(":"); + selectionStart = beforeParen + 1; + 
selectionEnd = selectionStart + lastColon; + } else { + selectionStart = beforeParen + 1; + selectionEnd = selectionStart + parenContent.length; + } + target.setSelectionRange(selectionStart, selectionEnd); return true; } @@ -58,7 +63,7 @@ function keyupEditAttention(event) { } // If the user hasn't selected anything, let's select their current parenthesis block or word - if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) { + if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')') && !selectCurrentParenthesisBlock('[', ']')) { selectCurrentWord(); } @@ -66,44 +71,31 @@ function keyupEditAttention(event) { var closeCharacter = ')'; var delta = opts.keyedit_precision_attention; + var start = selectionStart > 0 ? text[selectionStart - 1] : ""; + var end = text[selectionEnd]; - if (selectionStart > 0 && /<.*:-?[\d.]+>/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(">") + 1))) { + if (start == '<') { closeCharacter = '>'; delta = opts.keyedit_precision_extra; - } else if (selectionStart > 0 && /\(.*\)|\[.*\]/s.test(text.slice(selectionStart - 1, selectionEnd + 1))) { - let start = text[selectionStart - 1]; - let end = text[selectionEnd]; - if (opts.keyedit_convert) { - let numParen = 0; - - while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) { - numParen++; - } - - if (start == "(") { - weight = 1.1 ** numParen; - } else { - weight = (1 / 1.1) ** numParen; - } - - weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention; - - text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen); - selectionStart -= numParen - 1; - selectionEnd -= numParen - 1; + } else if (start == '(' && end == ')' || start == '[' && end == ']') { // convert old-style (((emphasis))) + let numParen = 0; + + while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) { + numParen++; + } + + if (start == "[") { + weight = (1 / 1.1) ** numParen; } else { - closeCharacter = null; - if (isPlus) { - text = text.slice(0, selectionStart) + start + text.slice(selectionStart, selectionEnd) + end + text.slice(selectionEnd); - selectionStart++; - selectionEnd++; - } else { - text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 1); - selectionStart--; - selectionEnd--; - } + weight = 1.1 ** numParen; } - } else if (selectionStart == 0 || !/\(.*:-?[\d.]+\)/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(")") + 1))) { + + weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention; + + text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen); + selectionStart -= numParen - 1; + selectionEnd -= numParen - 1; + } else if (start != '(') { // do not include spaces at the end while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { selectionEnd--; @@ -119,23 +111,22 @@ function keyupEditAttention(event) { selectionEnd++; } - if (closeCharacter) { - var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; - var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end)); - if (isNaN(weight)) return; - - weight += isPlus ? 
delta : -delta; - weight = parseFloat(weight.toPrecision(12)); - if (Number.isInteger(weight)) weight += ".0"; - - if (closeCharacter == ')' && weight == 1) { - var endParenPos = text.substring(selectionEnd).indexOf(')'); - text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1); - selectionStart--; - selectionEnd--; - } else { - text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); - } + if (text[selectionEnd] != ':') return; + var weightLength = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + weightLength)); + if (isNaN(weight)) return; + + weight += isPlus ? delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if (Number.isInteger(weight)) weight += ".0"; + + if (closeCharacter == ')' && weight == 1) { + var endParenPos = text.substring(selectionEnd).indexOf(')'); + text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1); + selectionStart--; + selectionEnd--; + } else { + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + weightLength); } target.focus(); diff --git a/modules/shared_options.py b/modules/shared_options.py index 32bf735323f..0a82216fff5 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -259,9 +259,8 @@ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~()[]<>| ", "Ctrl+up/down word delimiters"), + "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), - "keyedit_convert": OptionInfo(True, "Convert (attention) to (attention:1.1)"), "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), From 464fbcd92118bf00173b9982325fe6348201313e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 21 Oct 2023 09:09:32 +0300 Subject: [PATCH 0189/1174] fix the situation with emphasis editing (aaaa:1.1) bbbb (cccc:1.1) --- javascript/edit-attention.js | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index f3af9a4c345..04464100699 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -19,11 +19,17 @@ function keyupEditAttention(event) { let beforeParen = before.lastIndexOf(OPEN); if (beforeParen == -1) return false; + let beforeClosingParen = 
before.lastIndexOf(CLOSE); + if (beforeClosingParen != -1 && beforeClosingParen > beforeParen) return false; + // Find closing parenthesis around current cursor const after = text.substring(selectionStart); let afterParen = after.indexOf(CLOSE); if (afterParen == -1) return false; + let afterOpeningParen = after.indexOf(OPEN); + if (afterOpeningParen != -1 && afterOpeningParen < beforeParen) return false; + // Set the selection to the text between the parenthesis const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); if (/.*:-?[\d.]+/s.test(parenContent)) { From 443ca983ade333721930ea2f18f80b45762e2aea Mon Sep 17 00:00:00 2001 From: avantcontra Date: Sun, 22 Oct 2023 03:21:23 +0800 Subject: [PATCH 0190/1174] fix bug when using --gfpgan-models-path --- modules/gfpgan_model.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 8e0f13bdc7d..93567253f51 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -9,6 +9,7 @@ model_dir = "GFPGAN" user_path = None model_path = os.path.join(paths.models_path, model_dir) +model_file_path = None model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" have_gfpgan = False loaded_gfpgan_model = None @@ -17,24 +18,32 @@ def gfpgann(): global loaded_gfpgan_model global model_path + global model_file_path if loaded_gfpgan_model is not None: loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) return loaded_gfpgan_model if gfpgan_constructor is None: return None + + models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth']) - models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") if len(models) == 1 and models[0].startswith("http"): model_file = models[0] elif len(models) != 0: - latest_file = max(models, key=os.path.getctime) + gfp_models = [] + for item in models: + if 'GFPGAN' in os.path.basename(item): + gfp_models.append(item) + latest_file = max(gfp_models, key=os.path.getctime) model_file = latest_file else: print("Unable to load gfpgan model!") return None + if hasattr(facexlib.detection.retinaface, 'device'): facexlib.detection.retinaface.device = devices.device_gfpgan + model_file_path = model_file model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) loaded_gfpgan_model = model @@ -77,19 +86,25 @@ def setup_model(dirname): global user_path global have_gfpgan global gfpgan_constructor + global model_file_path + + facexlib_path = model_path + + if dirname is not None: + facexlib_path = dirname load_file_from_url_orig = gfpgan.utils.load_file_from_url facex_load_file_from_url_orig = facexlib.detection.load_file_from_url facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url def my_load_file_from_url(**kwargs): - return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) + return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path)) def facex_load_file_from_url(**kwargs): - return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) + return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) def facex_load_file_from_url2(**kwargs): - return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) + return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) 
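# note: gfpgan and facexlib hardcode where they download their weights; re-pointing their load_file_from_url helpers below redirects those downloads into the webui-managed model directories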
gfpgan.utils.load_file_from_url = my_load_file_from_url facexlib.detection.load_file_from_url = facex_load_file_from_url From 236dd55dbe895ba72a64567482ee67ab680c5344 Mon Sep 17 00:00:00 2001 From: avantcontra Date: Sun, 22 Oct 2023 04:32:13 +0800 Subject: [PATCH 0191/1174] fix Blank line contains whitespace --- modules/gfpgan_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 93567253f51..01d668ecdaf 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -25,7 +25,7 @@ def gfpgann(): if gfpgan_constructor is None: return None - + models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth']) if len(models) == 1 and models[0].startswith("http"): From 2d8c894b274d60a3e3563a2ace23c4ebcea9e652 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 13:43:31 -0700 Subject: [PATCH 0192/1174] refactor: use forward hook instead of custom forward --- extensions-builtin/Lora/network_oft.py | 33 +++++++++++++++++++------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 4e8382c18c7..8e561ab0baa 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -36,9 +36,11 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): # how do we revert this to unload the weights? def apply_to(self): self.org_forward = self.org_module[0].forward - self.org_module[0].forward = self.forward + #self.org_module[0].forward = self.forward + self.org_module[0].register_forward_hook(self.forward_hook) def get_weight(self, oft_blocks, multiplier=None): + self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) @@ -66,14 +68,10 @@ def calc_updown(self, orig_weight): output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def forward(self, x, y=None): - x = self.org_forward(x) - if self.multiplier() == 0.0: - return x - - # calculating R here is excruciatingly slow - #R = self.get_weight().to(x.device, dtype=x.dtype) + + def forward_hook(self, module, args, output): + #print(f'Forward hook in {self.network_key} called') + x = output R = self.R.to(x.device, dtype=x.dtype) if x.dim() == 4: @@ -83,3 +81,20 @@ def forward(self, x, y=None): else: x = torch.matmul(x, R) return x + + # def forward(self, x, y=None): + # x = self.org_forward(x) + # if self.multiplier() == 0.0: + # return x + + # # calculating R here is excruciatingly slow + # #R = self.get_weight().to(x.device, dtype=x.dtype) + # R = self.R.to(x.device, dtype=x.dtype) + + # if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + # else: + # x = torch.matmul(x, R) + # return x From 768354772853a1d27a9bf7e41bd6a6e4eac7a9c7 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 14:42:24 -0700 Subject: [PATCH 0193/1174] fix: return orig weights during updown, merge weights before forward --- extensions-builtin/Lora/network_oft.py | 90 ++++++++++++++++++++------ 1 file changed, 69 insertions(+), 21 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 8e561ab0baa..f5f32c23839 100644 --- 
a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,5 +1,6 @@ import torch import network +from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -29,23 +30,56 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] + self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) + #self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True) self.R = self.get_weight(self.oft_blocks) + + self.merged_weight = self.merge_weight() self.apply_to() + self.merged = False + + + def merge_weight(self): + org_sd = self.org_module[0].state_dict() + R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) + if self.org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) + else: + weight = torch.einsum("oi, op -> pi", self.org_weight, R) + org_sd['weight'] = weight + # replace weight + #self.org_module[0].load_state_dict(org_sd) + return weight + pass + + def replace_weight(self, new_weight): + org_sd = self.org_module[0].state_dict() + org_sd['weight'] = new_weight + self.org_module[0].load_state_dict(org_sd) + self.merged = True + + def restore_weight(self): + org_sd = self.org_module[0].state_dict() + org_sd['weight'] = self.org_weight + self.org_module[0].load_state_dict(org_sd) + self.merged = False + # replace forward method of original linear rather than replacing the module # how do we revert this to unload the weights? def apply_to(self): self.org_forward = self.org_module[0].forward #self.org_module[0].forward = self.forward + self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) self.org_module[0].register_forward_hook(self.forward_hook) def get_weight(self, oft_blocks, multiplier=None): - self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + new_norm_Q = torch.clamp(norm_Q, max=constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) #block_R_weighted = multiplier * block_R + (1 - multiplier) * I #R = torch.block_diag(*block_R_weighted) @@ -54,33 +88,47 @@ def get_weight(self, oft_blocks, multiplier=None): return R def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + #oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = self.get_weight(oft_blocks) - self.R = R + #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + ##self.R = R - # if orig_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - # else: - # weight = torch.einsum("oi, op -> pi", orig_weight, R) + #if orig_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + #else: + # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ R - output_shape = self.oft_blocks.shape + #updown = orig_weight @ R + #updown = weight + updown = 
torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) + #updown = orig_weight + output_shape = orig_weight.shape + #orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) + def pre_forward_hook(self, module, input): + if not self.merged: + self.replace_weight(self.merged_weight) + + def forward_hook(self, module, args, output): + if self.merged: + pass + #self.restore_weight() #print(f'Forward hook in {self.network_key} called') - x = output - R = self.R.to(x.device, dtype=x.dtype) - if x.dim() == 4: - x = x.permute(0, 2, 3, 1) - x = torch.matmul(x, R) - x = x.permute(0, 3, 1, 2) - else: - x = torch.matmul(x, R) - return x + #x = output + #R = self.R.to(x.device, dtype=x.dtype) + + #if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + #else: + # x = torch.matmul(x, R) + #return x # def forward(self, x, y=None): # x = self.org_forward(x) From fce86ab7d75690785f0f5b496f1b3aee922c0ae3 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:03:54 -0700 Subject: [PATCH 0194/1174] fix: support multiplier, no forward pass hook --- extensions-builtin/Lora/network_oft.py | 43 ++++++++++++++++++++------ 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index f5f32c23839..e0672ba6d82 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -32,21 +32,27 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.org_module: list[torch.Module] = [self.sd_module] self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) #self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True) - self.R = self.get_weight(self.oft_blocks) + init_multiplier = self.multiplier() * self.calc_scale() + self.last_multiplier = init_multiplier + self.R = self.get_weight(self.oft_blocks, init_multiplier) self.merged_weight = self.merge_weight() self.apply_to() self.merged = False + # weights_backup = getattr(self.org_module[0], 'network_weights_backup', None) + # if weights_backup is None: + # self.org_module[0].network_weights_backup = self.org_weight + def merge_weight(self): - org_sd = self.org_module[0].state_dict() + #org_sd = self.org_module[0].state_dict() R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) if self.org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) else: weight = torch.einsum("oi, op -> pi", self.org_weight, R) - org_sd['weight'] = weight + #org_sd['weight'] = weight # replace weight #self.org_module[0].load_state_dict(org_sd) return weight @@ -74,6 +80,7 @@ def apply_to(self): self.org_module[0].register_forward_hook(self.forward_hook) def get_weight(self, oft_blocks, multiplier=None): + multiplier = multiplier.to(oft_blocks.device, dtype=oft_blocks.dtype) constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) @@ -81,9 +88,9 @@ def get_weight(self, oft_blocks, multiplier=None): block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - 
block_Q).inverse()) - #block_R_weighted = multiplier * block_R + (1 - multiplier) * I - #R = torch.block_diag(*block_R_weighted) - R = torch.block_diag(*block_R) + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + R = torch.block_diag(*block_R_weighted) + #R = torch.block_diag(*block_R) return R @@ -93,6 +100,8 @@ def calc_updown(self, orig_weight): #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) ##self.R = R + #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + ##self.R = R #if orig_weight.dim() == 4: # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) #else: @@ -103,19 +112,33 @@ def calc_updown(self, orig_weight): updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) #updown = orig_weight output_shape = orig_weight.shape - #orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) def pre_forward_hook(self, module, input): - if not self.merged: + multiplier = self.multiplier() * self.calc_scale() + if not multiplier==self.last_multiplier or not self.merged: + + #if multiplier != self.last_multiplier or not self.merged: + self.R = self.get_weight(self.oft_blocks, multiplier) + self.last_multiplier = multiplier + self.merged_weight = self.merge_weight() self.replace_weight(self.merged_weight) + #elif not self.merged: + # self.replace_weight(self.merged_weight) def forward_hook(self, module, args, output): - if self.merged: - pass + pass + #output = output * self.multiplier() * self.calc_scale() + #if len(args) > 0: + # y = args[0] + # output = output + y + #return output + #if self.merged: + # pass #self.restore_weight() #print(f'Forward hook in {self.network_key} called') From 76f5abdbdb739133eff2ccefa36eac62bea3fa08 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:07:45 -0700 Subject: [PATCH 0195/1174] style: cleanup oft --- extensions-builtin/Lora/network_oft.py | 82 +++----------------------- 1 file changed, 7 insertions(+), 75 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e0672ba6d82..e462ccb1b76 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,6 +1,5 @@ import torch import network -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -31,33 +30,24 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.org_module: list[torch.Module] = [self.sd_module] self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) - #self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True) + init_multiplier = self.multiplier() * self.calc_scale() self.last_multiplier = init_multiplier + self.R = self.get_weight(self.oft_blocks, init_multiplier) self.merged_weight = self.merge_weight() self.apply_to() self.merged = False - # weights_backup = getattr(self.org_module[0], 'network_weights_backup', None) - # if weights_backup is None: - # self.org_module[0].network_weights_backup = self.org_weight - - def merge_weight(self): - #org_sd = self.org_module[0].state_dict() R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) if self.org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) else: weight = torch.einsum("oi, op 
-> pi", self.org_weight, R) - #org_sd['weight'] = weight - # replace weight - #self.org_module[0].load_state_dict(org_sd) return weight - pass - + def replace_weight(self, new_weight): org_sd = self.org_module[0].state_dict() org_sd['weight'] = new_weight @@ -70,9 +60,7 @@ def restore_weight(self): self.org_module[0].load_state_dict(org_sd) self.merged = False - - # replace forward method of original linear rather than replacing the module - # how do we revert this to unload the weights? + # FIXME: hook forward method of original linear, but how do we undo the hook when we are done? def apply_to(self): self.org_forward = self.org_module[0].forward #self.org_module[0].forward = self.forward @@ -90,82 +78,26 @@ def get_weight(self, oft_blocks, multiplier=None): block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) - #R = torch.block_diag(*block_R) return R def calc_updown(self, orig_weight): - #oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - - #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - ##self.R = R - - #R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - ##self.R = R - #if orig_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - #else: - # weight = torch.einsum("oi, op -> pi", orig_weight, R) - - #updown = orig_weight @ R - #updown = weight updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) - #updown = orig_weight output_shape = orig_weight.shape orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - + def pre_forward_hook(self, module, input): multiplier = self.multiplier() * self.calc_scale() - if not multiplier==self.last_multiplier or not self.merged: - #if multiplier != self.last_multiplier or not self.merged: + if not multiplier==self.last_multiplier or not self.merged: self.R = self.get_weight(self.oft_blocks, multiplier) self.last_multiplier = multiplier self.merged_weight = self.merge_weight() self.replace_weight(self.merged_weight) - #elif not self.merged: - # self.replace_weight(self.merged_weight) - + def forward_hook(self, module, args, output): pass - #output = output * self.multiplier() * self.calc_scale() - #if len(args) > 0: - # y = args[0] - # output = output + y - #return output - #if self.merged: - # pass - #self.restore_weight() - #print(f'Forward hook in {self.network_key} called') - - #x = output - #R = self.R.to(x.device, dtype=x.dtype) - - #if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - #else: - # x = torch.matmul(x, R) - #return x - - # def forward(self, x, y=None): - # x = self.org_forward(x) - # if self.multiplier() == 0.0: - # return x - - # # calculating R here is excruciatingly slow - # #R = self.get_weight().to(x.device, dtype=x.dtype) - # R = self.R.to(x.device, dtype=x.dtype) - - # if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - # else: - # x = torch.matmul(x, R) - # return x From de8ee92ed88b855098e273f576a27f4789f0693d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:37:17 -0700 Subject: [PATCH 0196/1174] fix: use merge_weight to cache value --- extensions-builtin/Lora/network_oft.py | 57 ++++++++++++++++++-------- 1 
file changed, 40 insertions(+), 17 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e462ccb1b76..ebe6740c5c0 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -29,23 +29,27 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] - self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) + #self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) init_multiplier = self.multiplier() * self.calc_scale() self.last_multiplier = init_multiplier self.R = self.get_weight(self.oft_blocks, init_multiplier) + self.hooks = [] self.merged_weight = self.merge_weight() - self.apply_to() + + #self.apply_to() + self.applied = False self.merged = False def merge_weight(self): - R = self.R.to(self.org_weight.device, dtype=self.org_weight.dtype) - if self.org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", self.org_weight, R) + org_weight = self.org_module[0].weight + R = self.R.to(org_weight.device, dtype=org_weight.dtype) + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) else: - weight = torch.einsum("oi, op -> pi", self.org_weight, R) + weight = torch.einsum("oi, op -> pi", org_weight, R) return weight def replace_weight(self, new_weight): @@ -55,17 +59,29 @@ def replace_weight(self, new_weight): self.merged = True def restore_weight(self): - org_sd = self.org_module[0].state_dict() - org_sd['weight'] = self.org_weight - self.org_module[0].load_state_dict(org_sd) - self.merged = False + pass + #org_sd = self.org_module[0].state_dict() + #org_sd['weight'] = self.org_weight + #self.org_module[0].load_state_dict(org_sd) + #self.merged = False # FIXME: hook forward method of original linear, but how do we undo the hook when we are done? 
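# note: register_forward_pre_hook and register_forward_hook each return a torch.utils.hooks.RemovableHandle, so keeping the handles and calling handle.remove() later is enough to undo the hook -- the remove_from() added below does exactly that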
def apply_to(self): - self.org_forward = self.org_module[0].forward - #self.org_module[0].forward = self.forward - self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) - self.org_module[0].register_forward_hook(self.forward_hook) + if not self.applied: + self.org_forward = self.org_module[0].forward + #self.org_module[0].forward = self.forward + prehook = self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) + hook = self.org_module[0].register_forward_hook(self.forward_hook) + self.hooks.append(prehook) + self.hooks.append(hook) + self.applied = True + + def remove_from(self): + if self.applied: + for hook in self.hooks: + hook.remove() + self.hooks = [] + self.applied = False def get_weight(self, oft_blocks, multiplier=None): multiplier = multiplier.to(oft_blocks.device, dtype=oft_blocks.dtype) constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) @@ -82,14 +98,22 @@ def get_weight(self, oft_blocks, multiplier=None): return R def calc_updown(self, orig_weight): + if not self.applied: + self.apply_to() + + self.merged_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) output_shape = orig_weight.shape - orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + orig_weight = self.merged_weight #output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) def pre_forward_hook(self, module, input): + #if not self.applied: + # self.apply_to() + multiplier = self.multiplier() * self.calc_scale() if not multiplier==self.last_multiplier or not self.merged: @@ -98,6 +122,5 @@ def pre_forward_hook(self, module, input): self.R = self.get_weight(self.oft_blocks, multiplier) self.last_multiplier = multiplier self.merged_weight = self.merge_weight() self.replace_weight(self.merged_weight) - def forward_hook(self, module, args, output): - pass + pass \ No newline at end of file From 4a50c9638c3eac860fb05ae603cd61aabf4cd1a9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 22 Oct 2023 08:54:24 -0700 Subject: [PATCH 0197/1174] refactor: remove unused OFT functions --- extensions-builtin/Lora/network_oft.py | 82 ++++---------------- 1 file changed, 10 insertions(+), 72 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ebe6740c5c0..3034a407eb3 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -29,98 +29,36 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.block_size = self.out_dim // self.num_blocks self.org_module: list[torch.Module] = [self.sd_module] - #self.org_weight = self.org_module[0].weight.to(self.org_module[0].weight.device, copy=True) - init_multiplier = self.multiplier() * self.calc_scale() - self.last_multiplier = init_multiplier - - self.R = self.get_weight(self.oft_blocks, init_multiplier) - - self.hooks = [] - self.merged_weight = self.merge_weight() - - #self.apply_to() - self.applied = False - self.merged = False - - def merge_weight(self): - org_weight = self.org_module[0].weight - R = self.R.to(org_weight.device, dtype=org_weight.dtype) + def merge_weight(self, R_weight, org_weight): + R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R) + weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) else: - weight = torch.einsum("oi, op -> pi", org_weight, R) + weight = torch.einsum("oi, op -> pi", org_weight, R_weight) return weight
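# note: the einsums in merge_weight contract the output-channel axis "o" of the original weight with the block-diagonal orthogonal matrix R ("op"), rotating output channels into "p" while leaving input and spatial axes untouched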
- def replace_weight(self, new_weight): - org_sd = self.org_module[0].state_dict() - org_sd['weight'] = new_weight - self.org_module[0].load_state_dict(org_sd) - self.merged = True - - def restore_weight(self): - pass - #org_sd = self.org_module[0].state_dict() - #org_sd['weight'] = self.org_weight - #self.org_module[0].load_state_dict(org_sd) - #self.merged = False - - # FIXME: hook forward method of original linear, but how do we undo the hook when we are done? - def apply_to(self): - if not self.applied: - self.org_forward = self.org_module[0].forward - #self.org_module[0].forward = self.forward - prehook = self.org_module[0].register_forward_pre_hook(self.pre_forward_hook) - hook = self.org_module[0].register_forward_hook(self.forward_hook) - self.hooks.append(prehook) - self.hooks.append(hook) - self.applied = True - - def remove_from(self): - if self.applied: - for hook in self.hooks: - hook.remove() - self.hooks = [] - self.applied = False - def get_weight(self, oft_blocks, multiplier=None): - multiplier = multiplier.to(oft_blocks.device, dtype=oft_blocks.dtype) constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) return R def calc_updown(self, orig_weight): - if not self.applied: - self.apply_to() - - self.merged_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) + R = self.get_weight(self.oft_blocks, self.multiplier()) + merged_weight = self.merge_weight(R, orig_weight) - updown = torch.zeros_like(orig_weight, device=orig_weight.device, dtype=orig_weight.dtype) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape - orig_weight = self.merged_weight - #output_shape = self.oft_blocks.shape + orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) - - def pre_forward_hook(self, module, input): - #if not self.applied: - # self.apply_to() - - multiplier = self.multiplier() * self.calc_scale() - - if not multiplier==self.last_multiplier or not self.merged: - self.R = self.get_weight(self.oft_blocks, multiplier) - self.last_multiplier = multiplier - self.merged_weight = self.merge_weight() - self.replace_weight(self.merged_weight) - - def forward_hook(self, module, args, output): - pass \ No newline at end of file From 3b8515d2c9abad7f0ccaac0215803716e861ee0e Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 22 Oct 2023 09:27:48 -0700 Subject: [PATCH 0198/1174] fix: multiplier applied twice in finalize_updown --- extensions-builtin/Lora/network_oft.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 3034a407eb3..efbdd296abe 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -54,7 +54,8 @@ def get_weight(self, oft_blocks, multiplier=None): return R def calc_updown(self, orig_weight): - R = self.get_weight(self.oft_blocks, self.multiplier()) + multiplier = 
self.multiplier() * self.calc_scale() + R = self.get_weight(self.oft_blocks, multiplier) merged_weight = self.merge_weight(R, orig_weight) updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight @@ -62,3 +63,23 @@ def calc_updown(self, orig_weight): orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) + + # override to remove the multiplier/scale factor; it's already multiplied in get_weight + def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): + #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) + + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + if ex_bias is not None: + ex_bias = ex_bias * self.multiplier() + + return updown, ex_bias From 6523edb8a45d4e09f11f3b4e1d133afa6fb65e53 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 22 Oct 2023 09:31:15 -0700 Subject: [PATCH 0199/1174] style: conform style --- extensions-builtin/Lora/network_oft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index efbdd296abe..e43c9a1dfff 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -63,7 +63,7 @@ def calc_updown(self, orig_weight): orig_weight = orig_weight return self.finalize_updown(updown, orig_weight, output_shape) - + # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) From 88b2ef3b04c37ec068fdfea9ba2596645e981b46 Mon Sep 17 00:00:00 2001 From: David Benson Date: Mon, 23 Oct 2023 08:16:26 -0400 Subject: [PATCH 0200/1174] Update prompts_from_file script to allow concatenating entries with the general prompt. --- scripts/prompts_from_file.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 50320d553bd..1aadf113594 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -108,6 +108,7 @@ def title(self): def ui(self, is_img2img): checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) + prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start") prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) @@ -118,9 +119,9 @@ def ui(self, is_img2img): # We don't shrink back to 1, because that causes the control to ignore [enter], and it may # be unclear to the user that shift-enter is needed. 
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False) - return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] + return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt] - def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): + def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str): lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x] p.do_not_save_grid = True @@ -158,6 +159,18 @@ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): for k, v in args.items(): setattr(copy_p, k, v) + if args.get("prompt") and p.prompt: + if prompt_position == "start": + copy_p.prompt = args.get("prompt") + " " + p.prompt + else: + copy_p.prompt = p.prompt + " " + args.get("prompt") + + if args.get("negative_prompt") and p.negative_prompt: + if prompt_position == "start": + copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt + else: + copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt") + proc = process_images(copy_p) images += proc.images From dfc4c27b2402a35a1820ffa549e74bb79873aaaa Mon Sep 17 00:00:00 2001 From: David Benson Date: Mon, 23 Oct 2023 08:26:40 -0400 Subject: [PATCH 0201/1174] linting issue --- scripts/prompts_from_file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 1aadf113594..3c09bb97660 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -164,7 +164,7 @@ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prom copy_p.prompt = args.get("prompt") + " " + p.prompt else: copy_p.prompt = p.prompt + " " + args.get("prompt") - + if args.get("negative_prompt") and p.negative_prompt: From eaa9f5162fbca2ebcb2682eb861bc7e5510a2b66 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 01:49:05 +0800 Subject: [PATCH 0202/1174] Add CPU fp8 support Since norm layers need fp32, I only convert the linear operation layers (Conv2d/Linear). The TE also has some PyTorch functions that don't support bf16 amp on CPU, so I add a condition to indicate whether the autocast is for the unet.
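For reference, a minimal sketch of the conversion rule this patch applies (illustrative only; `unet` and `cast_linear_ops_to_fp8` are placeholder names, not part of this change):

    import torch

    def cast_linear_ops_to_fp8(unet: torch.nn.Module) -> torch.nn.Module:
        # keep fp8 purely as a storage format for the matmul-heavy layers;
        # norm layers are skipped because they need fp32 to stay stable
        for module in unet.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                module.to(torch.float8_e4m3fn)
        return unet

Compute still happens in bf16: with fp8 weights on CPU, autocast() below returns torch.autocast("cpu", dtype=torch.bfloat16, enabled=True), so fp8 is storage-only.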
--- modules/devices.py | 6 +++++- modules/processing.py | 2 +- modules/sd_models.py | 20 ++++++++++++++++---- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 1d4eb563517..0cd2b55d279 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -71,6 +71,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") +fp8: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -93,10 +94,13 @@ def cond_cast_float(input): nv_rng = None -def autocast(disable=False): +def autocast(disable=False, unet=False): if disable: return contextlib.nullcontext() + if unet and fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 40598f5cff4..2df8a7ea7a8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): diff --git a/modules/sd_models.py b/modules/sd_models.py index 08af128fc4e..c5fe57bfec0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -391,12 +391,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if shared.cmd_opts.opt_unet_fp8_storage: + + if shared.cmd_opts.opt_unet_fp8_storage: + enable_fp8 = True + elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + enable_fp8 = True + + if enable_fp8: + devices.fp8 = True + if devices.device == devices.cpu: + for module in model.model.diffusion_model.modules(): + if isinstance(module, torch.nn.Conv2d): + module.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for cpu") + else: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From 9c1eba2af3a6f9cd6282b3a367656793cbe70c01 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 02:11:27 +0800 Subject: [PATCH 0203/1174] Fix lint --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index c5fe57bfec0..44d4038b25b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -396,7 +396,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: 
enable_fp8 = True - + if enable_fp8: devices.fp8 = True if devices.device == devices.cpu: From 1df6c8bfec4715610d64684b6ad2fa38c76c1df6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:36:43 +0800 Subject: [PATCH 0204/1174] fp8 for TE --- modules/sd_models.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 44d4038b25b..69395294300 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) timer.record("apply fp8 unet for cpu") else: + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") From 4830b251366436ee8499c003fe87e46ddb4a4581 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:53:37 +0800 Subject: [PATCH 0205/1174] Fix alphas_cumprod dtype --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 69395294300..23660454505 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -416,6 +416,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + model.alphas_cumprod = model.alphas_cumprod.to(torch.float32) devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From bf5067f50ca32cd4764638702e3cc38bca8bfd8b Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 12:54:28 +0800 Subject: [PATCH 0206/1174] Fix alphas cumprod --- modules/sd_models.py | 3 ++- modules/sd_models_xl.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 23660454505..7ed89a9c93f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -396,6 +396,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True + else: + enable_fp8 = False if enable_fp8: devices.fp8 = True @@ -416,7 +418,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") - model.alphas_cumprod = model.alphas_cumprod.to(torch.float32) devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 0112332161f..11259a369de 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -93,7 +93,7 @@ def extend_sdxl(model): model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() - 
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) model.conditioner.wrapped = torch.nn.Module() From dda067f64d3289cee3ffd65767126cb30ae73b13 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 19:53:22 +0800 Subject: [PATCH 0207/1174] ignore mps for fp8 --- modules/sd_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 7ed89a9c93f..ccb6afd2ade 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -392,7 +392,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if shared.cmd_opts.opt_unet_fp8_storage: + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.cmd_opts.opt_unet_fp8_storage: enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True From 0beb131c7ffae6f756a6339206da311232a36970 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 20:07:37 +0800 Subject: [PATCH 0208/1174] change torch version --- modules/launch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 8cdbafa50c8..636da679ecc 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -308,8 +308,8 @@ def requirements_met(requirements_file): def prepare_environment(): - torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') From 5121846d34d74aee9b55d48d35c1559a710051b0 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 25 Oct 2023 21:37:55 +0900 Subject: [PATCH 0209/1174] call state.jobnext() before postproces*() --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 40598f5cff4..70ad1ebed0a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -886,6 +886,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() + state.nextjob() + if p.scripts is not None: p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) @@ -958,8 +960,6 @@ def infotext(index=0, use_main_prompt=False): devices.torch_gc() - state.nextjob() - if not infotexts: infotexts.append(Processed(p, []).infotext(p, 0)) From d4d3134f6d2d232c7bcfa80900a362921e644976 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Oct 2023 15:24:26 +0800 Subject: [PATCH 0210/1174] ManualCast for 10/16 series gpu --- modules/devices.py | 57 ++++++++++++++++++++++++++++++++++++++----- modules/processing.py | 2 +- modules/sd_models.py | 21 +++++++++------- 3 files changed, 64 
insertions(+), 16 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 0cd2b55d279..c05f2b35e5e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -16,6 +16,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -60,8 +77,7 @@ def enable_tf32(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -92,15 +108,44 @@ def cond_cast_float(input): nv_rng = None - - -def autocast(disable=False, unet=False): +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + +@contextlib.contextmanager +def manual_autocast(): + def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward + + +def autocast(disable=False): + print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() - if unet and fp8 and device==cpu: + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 2df8a7ea7a8..40598f5cff4 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): diff --git a/modules/sd_models.py b/modules/sd_models.py index ccb6afd2ade..31bcb913a51 100644 --- a/modules/sd_models.py +++ 
b/modules/sd_models.py @@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if enable_fp8: devices.fp8 = True + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + if devices.device == devices.cpu: for module in model.model.diffusion_model.modules(): if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) elif isinstance(module, torch.nn.Linear): module.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for cpu") else: - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet") + timer.record("apply fp8") + else: + devices.fp8 = False devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From ddc2a3499b8cd120b4a42358bcd33137ce1d1e75 Mon Sep 17 00:00:00 2001 From: KohakuBlueleaf Date: Sat, 28 Oct 2023 16:52:35 +0800 Subject: [PATCH 0211/1174] Add MPS manual cast --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/devices.py b/modules/devices.py index c05f2b35e5e..d7c905c283a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -121,6 +121,8 @@ def manual_autocast(): def manual_cast_forward(self, *args, **kwargs): org_dtype = next(self.parameters()).dtype self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} result = self.org_forward(*args, **kwargs) self.to(org_dtype) return result @@ -136,7 +138,6 @@ def manual_cast_forward(self, *args, **kwargs): def autocast(disable=False): - print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() @@ -146,6 +147,9 @@ def autocast(disable=False): if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): return manual_autocast() + if has_mps() and shared.cmd_opts.precision != "full": + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() From f2b83517aa49268219706f9718d734c6f242fd93 Mon Sep 17 00:00:00 2001 From: Nick Harrison <42382362+nickpharrison@users.noreply.github.com> Date: Sun, 29 Oct 2023 15:40:13 +0000 Subject: [PATCH 0212/1174] Add new arguments to known command prompts --- modules/cmd_args.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 4e602a84270..c6d4b61239e 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -76,7 +76,9 @@ parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json')) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) -parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) 
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False) +parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None) +parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) @@ -90,7 +92,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) -parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts +parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") @@ -112,9 +114,8 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='does not do anything') +parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) -parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) 
-parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) +parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) From 844c23975f369f85c7179379ec419e1bd067de18 Mon Sep 17 00:00:00 2001 From: Nick Harrison <42382362+nickpharrison@users.noreply.github.com> Date: Sun, 29 Oct 2023 15:40:58 +0000 Subject: [PATCH 0213/1174] Add assertions for checking additional settings freezing parameters --- modules/options.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/modules/options.py b/modules/options.py index ab40aff73f0..1ac32d9503a 100644 --- a/modules/options.py +++ b/modules/options.py @@ -85,18 +85,35 @@ def __setattr__(self, key, value): if self.data is not None: if key in self.data or key in self.data_labels: + + # Check that settings aren't globally frozen assert not cmd_opts.freeze_settings, "changing settings is disabled" + # Get the info related to the setting being changed info = self.data_labels.get(key, None) if info.do_not_save: return + # Restrict component arguments comp_args = info.component_args if info else None if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: - raise RuntimeError(f"not possible to set {key} because it is restricted") + raise RuntimeError(f"not possible to set '{key}' because it is restricted") + # Check that this section isn't frozen + if cmd_opts.freeze_settings_in_sections is not None: + frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(','))) # Trim whitespace from section names + section_key = info.section[0] + section_name = info.section[1] + assert section_key not in frozen_sections, f"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections" + + # Check that this section of the settings isn't frozen + if cmd_opts.freeze_specific_settings is not None: + frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(','))) # Trim whitespace from setting keys + assert key not in frozen_keys, f"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings" + + # Check shorthand option which disables editing options in "saving-paths" if cmd_opts.hide_ui_dir_config and key in self.restricted_opts: - raise RuntimeError(f"not possible to set {key} because it is restricted") + raise RuntimeError(f"not possible to set '{key}' because it is restricted with --hide_ui_dir_config") self.data[key] = value return @@ -210,8 +227,6 @@ def dumpjson(self): def add_option(self, key, info): self.data_labels[key] = info - if key not in self.data: - self.data[key] = info.default def reorder(self): """reorder settings so that all items related to section always go together""" From be31e7e71a08dc27543d31aa6e6532463ccbf20f Mon Sep 17 00:00:00 2001 From: Nick Harrison <42382362+nickpharrison@users.noreply.github.com> Date: Sun, 29 Oct 2023 16:05:01 +0000 Subject: [PATCH 0214/1174] Remove blank line whitespace --- modules/options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/options.py b/modules/options.py index 1ac32d9503a..e270a42df85 100644 --- a/modules/options.py +++ b/modules/options.py @@ -85,7 +85,7 @@ def __setattr__(self, key, value): if self.data is not None: if key in self.data or key in self.data_labels: - + # Check 
that settings aren't globally frozen assert not cmd_opts.freeze_settings, "changing settings is disabled" From fbc5c531b9cfa949d60dae19420d01f8af186b55 Mon Sep 17 00:00:00 2001 From: Meerkov Date: Sun, 29 Oct 2023 15:37:08 -0700 Subject: [PATCH 0215/1174] Fix #13796 Fix comment error that makes understanding scheduling more confusing. --- modules/prompt_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 334efeef317..86b7acb50a8 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -5,7 +5,7 @@ from typing import List import lark -# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" +# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy'] From a2fad6ee055f3f4e98e46b6c2d912776fe608214 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 1 Nov 2023 22:34:27 -0700 Subject: [PATCH 0216/1174] test implementation based on kohaku diag-oft implementation --- extensions-builtin/Lora/network_oft.py | 59 +++++++++++++++++--------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e43c9a1dfff..ff61b3699bd 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,5 +1,6 @@ import torch import network +from einops import rearrange class ModuleTypeOFT(network.ModuleType): @@ -30,35 +31,51 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.org_module: list[torch.Module] = [self.sd_module] - def merge_weight(self, R_weight, org_weight): - R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - return weight + # def merge_weight(self, R_weight, org_weight): + # R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) + # if org_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) + # else: + # weight = torch.einsum("oi, op -> pi", org_weight, R_weight) + # weight = torch.einsum( + # "k n m, k n ... 
-> k m ...", + # self.oft_diag * scale + torch.eye(self.block_size, device=device), + # org_weight + # ) + # return weight def get_weight(self, oft_blocks, multiplier=None): - constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + # constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=constraint) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + # block_Q = oft_blocks - oft_blocks.transpose(1, 2) + # norm_Q = torch.norm(block_Q.flatten()) + # new_norm_Q = torch.clamp(norm_Q, max=constraint) + # block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + # m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + # block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - R = torch.block_diag(*block_R_weighted) + # block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + # R = torch.block_diag(*block_R_weighted) + #return R + return self.oft_blocks - return R def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - R = self.get_weight(self.oft_blocks, multiplier) - merged_weight = self.merge_weight(R, orig_weight) - - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #R = self.get_weight(self.oft_blocks, multiplier) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + #merged_weight = self.merge_weight(R, orig_weight) + + orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + orig_weight + ) + weight = rearrange(weight, 'k m ... 
-> (k m) ...') + + #updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape orig_weight = orig_weight From 65ccd6305fcf72347d5ed68f03095dced865ef6e Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 2 Nov 2023 00:11:32 -0700 Subject: [PATCH 0217/1174] detect diag_oft type --- extensions-builtin/Lora/networks.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 78a97033d58..7f814706adb 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -191,10 +191,17 @@ def load_network(name, network_on_disk): key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # kohya_ss OFT module elif sd_module is None and "oft_unet" in key_network_without_network_parts: key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # KohakuBlueLeaf OFT module + if sd_module is None and "oft_diag" in key: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + if sd_module is None: keys_failed_to_match[key_network] = key continue From d727ddfccdc6d474767be9dc3bf504150e81a8a5 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 2 Nov 2023 00:13:11 -0700 Subject: [PATCH 0218/1174] no idea what i'm doing, trying to support both type of OFT, kblueleaf diag_oft has MultiheadAttn which kohya's doesn't?, attempt create new module based off network_lora.py, errors about tensor dim mismatch --- extensions-builtin/Lora/network_oft.py | 192 +++++++++++++++++++------ 1 file changed, 145 insertions(+), 47 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff61b3699bd..e102eafc1f8 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,11 +1,12 @@ import torch import network from einops import rearrange +from modules import devices class ModuleTypeOFT(network.ModuleType): def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["oft_blocks"]): + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): return NetworkModuleOFT(net, weights) return None @@ -16,66 +17,117 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) - self.oft_blocks = weights.w["oft_blocks"] - self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] - self.num_blocks = self.dim - - if "Linear" in self.sd_module.__class__.__name__: + self.lin_module = None + # kohya-ss + if "oft_blocks" in weights.w.keys(): + self.is_kohya = True + self.oft_blocks = weights.w["oft_blocks"] + self.alpha = weights.w["alpha"] + self.dim = self.oft_blocks.shape[0] + elif "oft_diag" in weights.w.keys(): + self.is_kohya = False + self.oft_blocks = weights.w["oft_diag"] + # alpha is rank if alpha is 0 or None + if self.alpha is None: + pass + self.dim = self.oft_blocks.shape[0] # FIXME: almost 
certainly incorrect, assumes tensor is shape [*, m, n] + else: + raise ValueError("oft_blocks or oft_diag must be in weights dict") + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] + #if "Linear" in self.sd_module.__class__.__name__ or is_linear: + if is_linear: self.out_dim = self.sd_module.out_features - elif "Conv" in self.sd_module.__class__.__name__: + #elif hasattr(self.sd_module, "embed_dim"): + # self.out_dim = self.sd_module.embed_dim + #else: + # raise ValueError("Linear sd_module must have out_features or embed_dim") + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim + elif is_conv: self.out_dim = self.sd_module.out_channels + else: + raise ValueError("sd_module must be Linear or Conv") + - self.constraint = self.alpha * self.out_dim - self.block_size = self.out_dim // self.num_blocks + if self.is_kohya: + self.num_blocks = self.dim + self.block_size = self.out_dim // self.num_blocks + self.constraint = self.alpha * self.out_dim + #elif is_linear or is_conv: + else: + self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.constraint = None self.org_module: list[torch.Module] = [self.sd_module] - # def merge_weight(self, R_weight, org_weight): - # R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - # if org_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - # else: - # weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - # weight = torch.einsum( - # "k n m, k n ... -> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - # ) - # return weight + # if is_other_linear: + # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) + # module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + # with torch.no_grad(): + # if weight.shape != module.weight.shape: + # weight = weight.reshape(module.weight.shape) + # module.weight.copy_(weight) + # module.to(device=devices.cpu, dtype=devices.dtype) + # module.weight.requires_grad_(False) + # self.lin_module = module + #return module + + def merge_weight(self, R_weight, org_weight): + R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) + else: + weight = torch.einsum("oi, op -> pi", org_weight, R_weight) + #weight = torch.einsum( + # "k n m, k n ... 
-> k m ...", + # self.oft_diag * scale + torch.eye(self.block_size, device=device), + # org_weight + #) + return weight def get_weight(self, oft_blocks, multiplier=None): - # constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + if self.constraint is not None: + constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - # block_Q = oft_blocks - oft_blocks.transpose(1, 2) - # norm_Q = torch.norm(block_Q.flatten()) - # new_norm_Q = torch.clamp(norm_Q, max=constraint) - # block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - # m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - # block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + if self.constraint is not None: + new_norm_Q = torch.clamp(norm_Q, max=constraint) + else: + new_norm_Q = norm_Q + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - # block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - # R = torch.block_diag(*block_R_weighted) - #return R - return self.oft_blocks + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + R = torch.block_diag(*block_R_weighted) + return R + #return self.oft_blocks def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - #R = self.get_weight(self.oft_blocks, multiplier) - R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - #merged_weight = self.merge_weight(R, orig_weight) - - orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - orig_weight - ) - weight = rearrange(weight, 'k m ... -> (k m) ...') - - #updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + R = self.get_weight(self.oft_blocks, multiplier) + #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = self.merge_weight(R, orig_weight) + + #if self.lin_module is not None: + # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) + # weight = torch.mul(torch.mul(R, multiplier), orig_weight) + #else: + # orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + # weight = torch.einsum( + # 'k n m, k n ... -> k m ...', + # R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + # orig_weight + # ) + # weight = rearrange(weight, 'k m ... 
-> (k m) ...') + + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape orig_weight = orig_weight @@ -100,3 +152,49 @@ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): ex_bias = ex_bias * self.multiplier() return updown, ex_bias + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + From 759515316e8ec536f34fad616e8c6a33674a164b Mon Sep 17 00:00:00 2001 From: Emily Zeng Date: Thu, 2 Nov 2023 21:54:48 -0400 Subject: [PATCH 0219/1174] added accordion settings options --- modules/shared_options.py | 2 + modules/ui.py | 502 +++++++++++++++++++------------------- 2 files changed, 254 insertions(+), 250 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 0a82216fff5..5b07dd04150 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -270,6 +270,8 @@ "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), + "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(), + "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(), })) diff --git a/modules/ui.py b/modules/ui.py index bcf39199715..d05b9f55d89 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -344,84 +344,85 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Column(variant='compact', elem_id="txt2img_settings"): - scripts.scripts_txt2img.prepare_ui() + with gr.Accordion("Open for Settings", open=False) if shared.opts.img2img_settings_accordion else gr.Group(): + with gr.Column(variant='compact', elem_id="txt2img_settings"): + scripts.scripts_txt2img.prepare_ui() - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_name = 
create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height") + with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height") - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - elif category == "cfg": - with gr.Row(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + elif category == "cfg": + with gr.Row(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - elif category == "checkboxes": - with FormRow(elem_classes="checkboxes-row", variant="compact"): - pass + elif category == "checkboxes": + with FormRow(elem_classes="checkboxes-row", variant="compact"): + pass - elif category == "accordions": - with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"): - with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr: - with enable_hr.extra(): - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0) + elif category == "accordions": + with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"): + with InputAccordion(False, label="Hires. 
fix", elem_id="txt2img_hr") as enable_hr: + with enable_hr.extra(): + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0) - with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: + with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: - hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") - create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") + hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") + create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") - hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") + hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), 
value="Use same sampler") - with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: - with gr.Column(scale=80): - with gr.Row(): - hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"]) - with gr.Column(scale=80): - with gr.Row(): - hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) + with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: + with gr.Column(scale=80): + with gr.Row(): + hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"]) + with gr.Column(scale=80): + with gr.Row(): + hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) - scripts.scripts_txt2img.setup_ui_for_section(category) + scripts.scripts_txt2img.setup_ui_for_section(category) - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - elif category == "override_settings": - with FormRow(elem_id="txt2img_override_settings_row") as row: - override_settings = create_override_settings_dropdown('txt2img', row) + elif category == "override_settings": + with FormRow(elem_id="txt2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('txt2img', row) - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = scripts.scripts_txt2img.setup_ui() + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = scripts.scripts_txt2img.setup_ui() - if category not in {"accordions"}: - scripts.scripts_txt2img.setup_ui_for_section(category) + if category not in {"accordions"}: + scripts.scripts_txt2img.setup_ui_for_section(category) hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] @@ -560,214 +561,215 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Column(variant='compact', elem_id="img2img_settings"): - copy_image_buttons = [] - copy_image_destinations = {} - - def add_copy_image_controls(tab_name, elem): - with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): - 
gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}") - - for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): - if name == tab_name: - gr.Button(title, interactive=False) - copy_image_destinations[name] = elem - continue - - button = gr.Button(title) - copy_image_buttons.append((button, name, elem)) - - with gr.Tabs(elem_id="mode_img2img"): - img2img_selected_tab = gr.State(0) - - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) - add_copy_image_controls('img2img', init_img) - - with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) - add_copy_image_controls('sketch', sketch) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) - add_copy_image_controls('inpaint', init_img_with_mask) - - with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) - inpaint_color_sketch_orig = gr.State(None) - add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) - - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - - with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") - - with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
-                        gr.HTML(
-                            "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
-                            "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
-                            f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
-                            f"{hidden}</p>
" + with gr.Accordion("Open for Settings", open=False) if shared.opts.img2img_settings_accordion else gr.Group(): + with gr.Column(variant='compact', elem_id="img2img_settings"): + copy_image_buttons = [] + copy_image_destinations = {} + + def add_copy_image_controls(tab_name, elem): + with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): + gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}") + + for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): + if name == tab_name: + gr.Button(title, interactive=False) + copy_image_destinations[name] = elem + continue + + button = gr.Button(title) + copy_image_buttons.append((button, name, elem)) + + with gr.Tabs(elem_id="mode_img2img"): + img2img_selected_tab = gr.State(0) + + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) + add_copy_image_controls('img2img', init_img) + + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) + add_copy_image_controls('sketch', sketch) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) + add_copy_image_controls('inpaint', init_img_with_mask) + + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) + inpaint_color_sketch_orig = gr.State(None) + add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) + + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) + + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") + + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: + hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+                            gr.HTML(
+                                "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
+                                "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
+                                f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
+                                f"{hidden}</p>
" + ) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + with gr.Accordion("PNG info", open=False): + img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") + img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") + + img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] + + for i, tab in enumerate(img2img_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) + + def copy_image(img): + if isinstance(img, dict) and 'image' in img: + return img['image'] + + return img + + for button, name, elem in copy_image_buttons: + button.click( + fn=copy_image, + inputs=[elem], + outputs=[copy_image_destinations[name]], + ) + button.click( + fn=lambda: None, + _js=f"switch_to_{name.replace(' ', '_')}", + inputs=[], + outputs=[], ) - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") - with gr.Accordion("PNG info", open=False): - img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") - img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") - img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") - - img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] - - for i, tab in enumerate(img2img_tabs): - tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) - - def copy_image(img): - if isinstance(img, dict) and 'image' in img: - return img['image'] - - return img - - for button, name, elem in copy_image_buttons: - button.click( - fn=copy_image, - inputs=[elem], - outputs=[copy_image_destinations[name]], - ) - button.click( - fn=lambda: None, - _js=f"switch_to_{name.replace(' ', '_')}", - inputs=[], - outputs=[], - ) - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - scripts.scripts_img2img.prepare_ui() - - for 
category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - selected_scale_tab = gr.State(value=0) - - with gr.Tabs(): - with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height") - detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img") - - with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: - scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") - - with FormRow(): - scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") - gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") - button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") - - on_change_args = dict( - fn=resize_from_to_html, - _js="currentImg2imgSourceResolution", - inputs=[dummy_component, dummy_component, scale_by], - outputs=scale_by_html, - show_progress=False, - ) - - scale_by.release(**on_change_args) - button_update_resize_to.click(**on_change_args) - - # the code below is meant to update the resolution label after the image in the image selection UI has changed. - # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. - # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. 
- for component in [init_img, sketch]: - component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) - tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + scripts.scripts_img2img.prepare_ui() - elif category == "denoising": - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") - elif category == "cfg": - with gr.Row(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False) + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + selected_scale_tab = gr.State(value=0) + + with gr.Tabs(): + with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height") + detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img") + + with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: + scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") + + with FormRow(): + scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") + gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") + button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") + + on_change_args = dict( + fn=resize_from_to_html, + _js="currentImg2imgSourceResolution", + inputs=[dummy_component, dummy_component, scale_by], + outputs=scale_by_html, + show_progress=False, + ) + + scale_by.release(**on_change_args) + button_update_resize_to.click(**on_change_args) + + # the code below is meant to update the resolution label after the image in the image selection UI has changed. + # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. 
+ # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. + for component in [init_img, sketch]: + component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) + + tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) + tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "denoising": + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "cfg": + with gr.Row(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False) - elif category == "checkboxes": - with FormRow(elem_classes="checkboxes-row", variant="compact"): - pass + elif category == "checkboxes": + with FormRow(elem_classes="checkboxes-row", variant="compact"): + pass - elif category == "accordions": - with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"): - scripts.scripts_img2img.setup_ui_for_section(category) + elif category == "accordions": + with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"): + scripts.scripts_img2img.setup_ui_for_section(category) - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - elif category == "override_settings": - with FormRow(elem_id="img2img_override_settings_row") as row: - override_settings = create_override_settings_dropdown('img2img', row) + elif category == "override_settings": + with FormRow(elem_id="img2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('img2img', row) - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = scripts.scripts_img2img.setup_ui() + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = scripts.scripts_img2img.setup_ui() - elif category == "inpaint": - with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + elif category == "inpaint": + with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: + with FormRow(): + mask_blur = 
gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - with FormRow(): - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + with FormRow(): + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - def select_img2img_tab(tab): - return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), - for i, elem in enumerate(img2img_tabs): - elem.select( - fn=lambda tab=i: select_img2img_tab(tab), - inputs=[], - outputs=[inpaint_controls, mask_alpha], - ) + for i, elem in enumerate(img2img_tabs): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) - if category not in {"accordions"}: - scripts.scripts_img2img.setup_ui_for_section(category) + if category not in {"accordions"}: + scripts.scripts_img2img.setup_ui_for_section(category) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) From 8052a4971e1be48e1df2535284a7791cd1ad39ae Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Fri, 3 Nov 2023 00:59:19 -0600 Subject: [PATCH 0220/1174] Fix parenthesis auto selection Fixes #13813 --- javascript/edit-attention.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 04464100699..688c2f112d6 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -28,7 +28,7 @@ function keyupEditAttention(event) { if (afterParen == -1) return false; let afterOpeningParen = after.indexOf(OPEN); - if (afterOpeningParen != -1 && afterOpeningParen < beforeParen) return false; + if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false; // Set the selection to the text between the parenthesis const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); From cc80a09d82afae793800a033a1f525f5dc797cff Mon 
Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 4 Nov 2023 00:50:30 +0900 Subject: [PATCH 0221/1174] Update requirements_versions.txt --- requirements_versions.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_versions.txt b/requirements_versions.txt index 7d27f2be3bb..cb7403a9d4b 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -29,3 +29,4 @@ torch torchdiffeq==0.2.3 torchsde==0.2.6 transformers==4.30.2 +httpx==0.24.1 From bda2ecdbf58fd33b4ad3036ed5cc13eef02747ae Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 3 Nov 2023 19:44:57 +0300 Subject: [PATCH 0222/1174] Merge pull request #13839 from AUTOMATIC1111/httpx==0.24.1 requirements_versions httpx==0.24.1 --- requirements_versions.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_versions.txt b/requirements_versions.txt index f8ae1f385ae..e84bd4270a8 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -29,3 +29,4 @@ torch torchdiffeq==0.2.3 torchsde==0.2.5 transformers==4.30.2 +httpx==0.24.1 From 4afaaf8a020c1df457bcf7250cb1c7f609699fa7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 3 Nov 2023 19:50:14 +0300 Subject: [PATCH 0223/1174] add changelog entry --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cd3572c8e0..2c72359fc42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## 1.6.1 + +### Bug Fixes: + * fix an error causing the webui to fail to start ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839)) + ## 1.6.0 ### Features: From fe1967a4c4a02eccfa45b65ee19a5b0773ced31c Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Fri, 3 Nov 2023 17:52:55 -0700 Subject: [PATCH 0224/1174] skip multihead attn for now --- extensions-builtin/Lora/network_oft.py | 54 ++++++++++++++++++-------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e102eafc1f8..979a20476ad 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -18,6 +18,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) self.lin_module = None + self.org_module: list[torch.Module] = [self.sd_module] # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -30,7 +31,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): # alpha is rank if alpha is 0 or None if self.alpha is None: pass - self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] + self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] else: raise ValueError("oft_blocks or oft_diag must be in weights dict") @@ -46,6 +47,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): # raise ValueError("Linear sd_module must have out_features or embed_dim") elif is_other_linear: self.out_dim = self.sd_module.embed_dim + #self.org_weight = self.org_module[0].weight +# if hasattr(self.sd_module, "in_proj_weight"): +# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] +# if hasattr(self.sd_module, "out_proj_weight"): +# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0] +# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] elif is_conv: self.out_dim 
= self.sd_module.out_channels else: @@ -58,10 +65,9 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.constraint = self.alpha * self.out_dim #elif is_linear or is_conv: else: - self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None - self.org_module: list[torch.Module] = [self.sd_module] # if is_other_linear: # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) @@ -110,25 +116,39 @@ def get_weight(self, oft_blocks, multiplier=None): def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - R = self.get_weight(self.oft_blocks, multiplier) - #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - merged_weight = self.merge_weight(R, orig_weight) + is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] + if self.is_kohya and not is_other_linear: + R = self.get_weight(self.oft_blocks, multiplier) + #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = self.merge_weight(R, orig_weight) + elif not self.is_kohya and not is_other_linear: + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + orig_weight=orig_weight.permute(1, 0) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + orig_weight=orig_weight.permute(1, 0) + #merged_weight=merged_weight.permute(1, 0) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + else: + # skip for now + updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) + output_shape = (orig_weight.shape[1], orig_weight.shape[1]) #if self.lin_module is not None: # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) # weight = torch.mul(torch.mul(R, multiplier), orig_weight) #else: - # orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - # weight = torch.einsum( - # 'k n m, k n ... -> k m ...', - # R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - # orig_weight - # ) - # weight = rearrange(weight, 'k m ... 
-> (k m) ...')
-
-        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
-        #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
-        output_shape = orig_weight.shape
+        orig_weight = orig_weight

         return self.finalize_updown(updown, orig_weight, output_shape)

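A note on the zero-filled MultiheadAttention branch added above: torch.nn.MultiheadAttention keeps its query/key/value projections fused in a single non-square in_proj_weight of shape (3 * embed_dim, embed_dim), so the square block-diagonal rotation OFT builds for a layer's output dimension has nothing it can multiply into directly, and returning an all-zero delta simply leaves the layer unchanged. A minimal sketch of that fallback, using the hypothetical helper name skipped_updown and assuming orig_weight is the fused projection tensor:

    import torch

    def skipped_updown(orig_weight: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        # Hypothetical illustration: an (embed_dim, embed_dim) block of zeros means the
        # applied update is a no-op, mirroring the torch.zeros placeholder in the diff above.
        embed_dim = orig_weight.shape[1]
        updown = torch.zeros(embed_dim, embed_dim, dtype=orig_weight.dtype, device=orig_weight.device)
        return updown, (embed_dim, embed_dim)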
From f6c8201e5663ca2182a66c8eca63ce4801d52849 Mon Sep 17 00:00:00 2001
From: v0xie <28695009+v0xie@users.noreply.github.com>
Date: Fri, 3 Nov 2023 19:35:15 -0700
Subject: [PATCH 0225/1174] refactor: move factorization to lyco_helpers,
 separate calc_updown for kohya and kb

---
 extensions-builtin/Lora/lyco_helpers.py |  47 +++++++++
 extensions-builtin/Lora/network_oft.py  | 131 ++++++------------------
 2 files changed, 77 insertions(+), 101 deletions(-)

diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
index 279b34bc928..1679a0ce633 100644
--- a/extensions-builtin/Lora/lyco_helpers.py
+++ b/extensions-builtin/Lora/lyco_helpers.py
@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
     up = up.reshape(up.size(0), -1)
     down = down.reshape(down.size(0), -1)
     return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+    '''
+    return a tuple of two value of input dimension decomposed by the number closest to factor
+    second value is higher or equal than first value.
+
+    In LoRA with Kroneckor Product, first value is a value for weight scale.
+    secon value is a value for weight.
+
+    Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
+
+    examples)
+    factor
+        -1               2                4               8               16               ...
+        127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
+        128 -> 8, 16    128 -> 2, 64    128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
+        250 -> 10, 25   250 -> 2, 125   250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
+        360 -> 8, 45    360 -> 2, 180   360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
+        512 -> 16, 32   512 -> 2, 256   512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
+        1024 -> 32, 32  1024 -> 2, 512  1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
+    '''
+
+    if factor > 0 and (dimension % factor) == 0:
+        m = factor
+        n = dimension // factor
+        if m > n:
+            n, m = m, n
+        return m, n
+    if factor < 0:
+        factor = dimension
+    m, n = 1, dimension
+    length = m + n
+    while m<n:
+        new_m = m + 1
+        while dimension%new_m != 0:
+            new_m += 1
+        new_n = dimension // new_m
+        if new_m + new_n > length or new_m>factor:
+            break
+        else:
+            m, n = new_m, new_n
+    if m > n:
+        n, m = m, n
+    return m, n
+
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 979a20476ad..2be67fe5305 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -1,7 +1,7 @@
 import torch
 import network
+from lyco_helpers import factorization
 from einops import rearrange
-from modules import devices


 class ModuleTypeOFT(network.ModuleType):
@@ -11,7 +11,8 @@ def create_module(self, net: network.Network, weights: network.NetworkWeights):
         return None


-# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
 class NetworkModuleOFT(network.NetworkModule):
     def __init__(self, net: network.Network, weights: network.NetworkWeights):

         self.lin_module = None
         self.org_module: list[torch.Module] = [self.sd_module]
+
         # kohya-ss
         if "oft_blocks" in weights.w.keys():
             self.is_kohya = True
@@ -37,61 +39,31 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):

         is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
-        is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
-        #if "Linear" in self.sd_module.__class__.__name__ or is_linear:
+        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
+
         if is_linear:
             self.out_dim = self.sd_module.out_features
-        #elif hasattr(self.sd_module, "embed_dim"):
-        #    self.out_dim = self.sd_module.embed_dim
-        #else:
-        #    raise ValueError("Linear sd_module must have out_features or embed_dim")
         elif is_other_linear:
             self.out_dim = self.sd_module.embed_dim
-            #self.org_weight = self.org_module[0].weight
-#            if hasattr(self.sd_module, "in_proj_weight"):
-#                self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
-#            if hasattr(self.sd_module, "out_proj_weight"):
-#                self.out_proj_dim = self.sd_module.out_proj_weight.shape[0]
-#            self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
         elif is_conv:
             self.out_dim = self.sd_module.out_channels
         else:
             raise ValueError("sd_module must be Linear or Conv")

         if self.is_kohya:
             self.num_blocks = self.dim
             self.block_size = self.out_dim // self.num_blocks
             self.constraint = self.alpha * self.out_dim
-        #elif is_linear or is_conv:
         else:
             self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
             self.constraint = None
-
-        # if is_other_linear:
-        #     weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
-        #     module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        #     with torch.no_grad():
-        #         if weight.shape != module.weight.shape:
-        #             weight = weight.reshape(module.weight.shape)
-        #         module.weight.copy_(weight)
-        #     module.to(device=devices.cpu, dtype=devices.dtype)
-        #     module.weight.requires_grad_(False)
-        #     self.lin_module = module
-        #return module
-
     def merge_weight(self, R_weight, org_weight):
         R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
         if org_weight.dim() == 4:
             weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight)
         else:
             weight = torch.einsum("oi, op -> pi", org_weight, R_weight)
-        #weight = torch.einsum(
-        #    "k n m, k n ... 
-> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - #) return weight def get_weight(self, oft_blocks, multiplier=None): @@ -111,48 +83,51 @@ def get_weight(self, oft_blocks, multiplier=None): block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) return R - #return self.oft_blocks + def calc_updown_kohya(self, orig_weight, multiplier): + R = self.get_weight(self.oft_blocks, multiplier) + merged_weight = self.merge_weight(R, orig_weight) - def calc_updown(self, orig_weight): - multiplier = self.multiplier() * self.calc_scale() - is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] - if self.is_kohya and not is_other_linear: - R = self.get_weight(self.oft_blocks, multiplier) - #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - merged_weight = self.merge_weight(R, orig_weight) - elif not self.is_kohya and not is_other_linear: + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + orig_weight = orig_weight + return self.finalize_updown(updown, orig_weight, output_shape) + + def calc_updown_kb(self, orig_weight, multiplier): + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + + if not is_other_linear: if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks) merged_weight = torch.einsum( 'k n m, k n ... -> k m ...', R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - merged_weight + merged_weight ) merged_weight = rearrange(merged_weight, 'k m ... 
-> (k m) ...') + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) - #merged_weight=merged_weight.permute(1, 0) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: - # skip for now + # FIXME: skip MultiheadAttention for now updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) output_shape = (orig_weight.shape[1], orig_weight.shape[1]) - #if self.lin_module is not None: - # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) - # weight = torch.mul(torch.mul(R, multiplier), orig_weight) - #else: - - orig_weight = orig_weight - return self.finalize_updown(updown, orig_weight, output_shape) + def calc_updown(self, orig_weight): + multiplier = self.multiplier() * self.calc_scale() + if self.is_kohya: + return self.calc_updown_kohya(orig_weight, multiplier) + else: + return self.calc_updown_kb(orig_weight, multiplier) + # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) @@ -172,49 +147,3 @@ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): ex_bias = ex_bias * self.multiplier() return updown, ex_bias - -# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py -def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: - ''' - return a tuple of two value of input dimension decomposed by the number closest to factor - second value is higher or equal than first value. - - In LoRA with Kroneckor Product, first value is a value for weight scale. - secon value is a value for weight. - - Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. - - examples) - factor - -1 2 4 8 16 ... 
-        127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
-        128 -> 8, 16    128 -> 2, 64    128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
-        250 -> 10, 25   250 -> 2, 125   250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
-        360 -> 8, 45    360 -> 2, 180   360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
-        512 -> 16, 32   512 -> 2, 256   512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
-        1024 -> 32, 32  1024 -> 2, 512  1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
-    '''
-
-    if factor > 0 and (dimension % factor) == 0:
-        m = factor
-        n = dimension // factor
-        if m > n:
-            n, m = m, n
-        return m, n
-    if factor < 0:
-        factor = dimension
-    m, n = 1, dimension
-    length = m + n
-    while m<n:
-        new_m = m + 1
-        while dimension%new_m != 0:
-            new_m += 1
-        new_n = dimension // new_m
-        if new_m + new_n > length or new_m>factor:
-            break
-        else:
-            m, n = new_m, new_n
-    if m > n:
-        n, m = m, n
-    return m, n
-
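The commit that follows collapses the kohya-ss and LyCORIS code paths into a single block-diagonal update: the output dimension is factorized into num_blocks x block_size blocks, the stored tensor Q is skew-symmetrized as Q - Q^T (the commit's own comment notes that results degrade significantly without this), and each block of the original weight is rotated by R = I + multiplier * Q before the difference against the original weight is taken as the update. A condensed sketch of that computation as a hypothetical free-standing helper, assuming oft_blocks has shape (num_blocks, block_size, block_size) and the first dimension of orig_weight equals num_blocks * block_size:

    import torch
    from einops import rearrange

    def block_oft_updown(orig_weight, oft_blocks, multiplier, num_blocks, block_size):
        # Skew-symmetrize Q so that I + multiplier * Q stays close to an orthogonal rotation.
        q = oft_blocks - oft_blocks.transpose(1, 2)
        r = q * multiplier + torch.eye(block_size, device=orig_weight.device)
        # Rotate each block_size-sized slice of the output dimension independently.
        w = rearrange(orig_weight, '(k n) ... -> k n ...', k=num_blocks, n=block_size)
        w = torch.einsum('k n m, k n ... -> k m ...', r, w)
        w = rearrange(w, 'k m ... -> (k m) ...')
        return w - orig_weight  # the delta later applied on top of the original weight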
From 329c8bacce706811776e1c1c6a0d39b46886a268 Mon Sep 17 00:00:00 2001
From: v0xie <28695009+v0xie@users.noreply.github.com>
Date: Sat, 4 Nov 2023 14:54:36 -0700
Subject: [PATCH 0226/1174] refactor: use same updown for both kohya OFT and
 LyCORIS diag-oft

---
 extensions-builtin/Lora/network_oft.py | 91 +++++++++++++++++++++-----
 1 file changed, 74 insertions(+), 17 deletions(-)

diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 2be67fe5305..e4aa082b7f5 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -2,6 +2,7 @@
 import torch
 import network
 from lyco_helpers import factorization
 from einops import rearrange
+from modules import devices


 class ModuleTypeOFT(network.ModuleType):
@@ -24,12 +25,14 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         # kohya-ss
         if "oft_blocks" in weights.w.keys():
             self.is_kohya = True
-            self.oft_blocks = weights.w["oft_blocks"]
+            self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
             self.alpha = weights.w["alpha"]
-            self.dim = self.oft_blocks.shape[0]
+            self.dim = self.oft_blocks.shape[0] # lora dim
+            #self.oft_blocks = rearrange(self.oft_blocks, 'k m ... -> (k m) ...')
         elif "oft_diag" in weights.w.keys():
             self.is_kohya = False
-            self.oft_blocks = weights.w["oft_diag"]
+            self.oft_blocks = weights.w["oft_diag"] # (num_blocks, block_size, block_size)
+
             # alpha is rank if alpha is 0 or None
             if self.alpha is None:
                 pass
@@ -51,12 +54,57 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
             raise ValueError("sd_module must be Linear or Conv")

         if self.is_kohya:
-            self.num_blocks = self.dim
-            self.block_size = self.out_dim // self.num_blocks
+            #self.num_blocks = self.dim
+            #self.block_size = self.out_dim // self.num_blocks
+            #self.block_size = self.dim
+            #self.num_blocks = self.out_dim // self.block_size
             self.constraint = self.alpha * self.out_dim
+            self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
         else:
-            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
             self.constraint = None
+            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
+
+        if is_other_linear:
+            self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True)
+
+
+    def create_module(self, weights, key, none_ok=False):
+        weight = weights.get(key)
+
+        if weight is None and none_ok:
+            return None
+
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
+        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+
+        if is_linear:
+            weight = weight.reshape(weight.shape[0], -1)
+            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+        elif is_conv and key == "lora_down.weight" or key == "dyn_up":
+            if len(weight.shape) == 2:
+                weight = weight.reshape(weight.shape[0], -1, 1, 1)
+
+            if weight.shape[2] != 1 or weight.shape[3] != 1:
+                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+            else:
+                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+        elif is_conv and key == "lora_mid.weight":
+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+        elif is_conv and key == "lora_up.weight" or key == "dyn_down":
+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+        else:
+            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
+
+        with torch.no_grad():
+            if weight.shape != module.weight.shape:
+                weight = weight.reshape(module.weight.shape)
+            module.weight.copy_(weight)
+
+        module.to(device=devices.cpu, dtype=devices.dtype)
+        module.weight.requires_grad_(False)
+
+        return module
+

     def merge_weight(self, R_weight, org_weight):
         R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
@@ -77,7 +125,8 @@ def get_weight(self, oft_blocks, multiplier=None):
         else:
             new_norm_Q = norm_Q
         block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
-        m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
+        m_I = torch.eye(self.num_blocks, device=oft_blocks.device).unsqueeze(0).repeat(self.block_size, 1, 1)
+        #m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
         block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse())

         block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I
@@ -97,25 +146,33 @@ def calc_updown_kb(self, orig_weight, multiplier):
         is_other_linear = type(self.sd_module) in 
[torch.nn.MultiheadAttention] if not is_other_linear: - if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - orig_weight=orig_weight.permute(1, 0) + #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + # orig_weight=orig_weight.permute(1, 0) + + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + + # without this line the results are significantly worse / less accurate + oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) + + R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) - R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) merged_weight = torch.einsum( 'k n m, k n ... -> k m ...', - R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + R, merged_weight ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - orig_weight=orig_weight.permute(1, 0) + #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: + # orig_weight=orig_weight.permute(1, 0) updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: # FIXME: skip MultiheadAttention for now + #up = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) output_shape = (orig_weight.shape[1], orig_weight.shape[1]) @@ -123,10 +180,10 @@ def calc_updown_kb(self, orig_weight, multiplier): def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - if self.is_kohya: - return self.calc_updown_kohya(orig_weight, multiplier) - else: - return self.calc_updown_kb(orig_weight, multiplier) + #if self.is_kohya: + # return self.calc_updown_kohya(orig_weight, multiplier) + #else: + return self.calc_updown_kb(orig_weight, multiplier) # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): From bbf00a96afb2215f13cc72a7908225ae300c423d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 4 Nov 2023 14:56:47 -0700 Subject: [PATCH 0227/1174] refactor: remove unused function --- extensions-builtin/Lora/network_oft.py | 47 -------------------------- 1 file changed, 47 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e4aa082b7f5..93402bb28f1 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -2,7 +2,6 @@ import network from lyco_helpers import factorization from einops import rearrange -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -54,58 +53,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): raise ValueError("sd_module must be Linear or Conv") if self.is_kohya: - #self.num_blocks = self.dim - #self.block_size = self.out_dim // self.num_blocks - #self.block_size = self.dim - #self.num_blocks = self.out_dim // self.block_size self.constraint = self.alpha * self.out_dim self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - if 
is_other_linear: - self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True) - - - def create_module(self, weights, key, none_ok=False): - weight = weights.get(key) - - if weight is None and none_ok: - return None - - is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] - is_conv = type(self.sd_module) in [torch.nn.Conv2d] - - if is_linear: - weight = weight.reshape(weight.shape[0], -1) - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif is_conv and key == "lora_down.weight" or key == "dyn_up": - if len(weight.shape) == 2: - weight = weight.reshape(weight.shape[0], -1, 1, 1) - - if weight.shape[2] != 1 or weight.shape[3] != 1: - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) - else: - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif is_conv and key == "lora_mid.weight": - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) - elif is_conv and key == "lora_up.weight" or key == "dyn_down": - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - else: - raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') - - with torch.no_grad(): - if weight.shape != module.weight.shape: - weight = weight.reshape(module.weight.shape) - module.weight.copy_(weight) - - module.to(device=devices.cpu, dtype=devices.dtype) - module.weight.requires_grad_(False) - - return module - - def merge_weight(self, R_weight, org_weight): R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: From 2b06cefe66684ed2648d3221efbc36aeaae99a2f Mon Sep 17 00:00:00 2001 From: gibiee <37574274+gibiee@users.noreply.github.com> Date: Sun, 5 Nov 2023 11:37:23 +0900 Subject: [PATCH 0228/1174] correct a typo modify "defaul" to "default" --- modules/cmd_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 4e602a84270..a9fb9bfa3ef 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -107,7 +107,7 @@ parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True) -parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions") +parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions") parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) From 6b8c661c49796bba093ca8a8301e81d28afb9832 Mon Sep 
17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 5 Nov 2023 08:55:54 +0300 Subject: [PATCH 0229/1174] add a visible checkbox to input accordion --- javascript/inputAccordion.js | 79 ++++++++++++++++++++++++------------ style.css | 5 +++ 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/javascript/inputAccordion.js b/javascript/inputAccordion.js index f2839852ee7..8fc01230a92 100644 --- a/javascript/inputAccordion.js +++ b/javascript/inputAccordion.js @@ -1,37 +1,66 @@ -var observerAccordionOpen = new MutationObserver(function(mutations) { - mutations.forEach(function(mutationRecord) { - var elem = mutationRecord.target; - var open = elem.classList.contains('open'); +function inputAccordionChecked(id, checked) { + var accordion = gradioApp().getElementById(id); + accordion.visibleCheckbox.checked = checked; + accordion.onVisibleCheckboxChange(); +} - var accordion = elem.parentNode; - accordion.classList.toggle('input-accordion-open', open); +function setupAccordion(accordion){ + var labelWrap = accordion.querySelector('.label-wrap'); + var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); + var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); + var span = labelWrap.querySelector('span'); + var linked = true; - var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); - checkbox.checked = open; - updateInput(checkbox); + var isOpen = function(){ return labelWrap.classList.contains('open'); } - var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); - if (extra) { - extra.style.display = open ? "" : "none"; - } + var observerAccordionOpen = new MutationObserver(function(mutations) { + mutations.forEach(function(mutationRecord) { + accordion.classList.toggle('input-accordion-open', isOpen()); + + if(linked){ + accordion.visibleCheckbox.checked = isOpen(); + accordion.onVisibleCheckboxChange(); + } + }); }); -}); + observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); -function inputAccordionChecked(id, checked) { - var label = gradioApp().querySelector('#' + id + " .label-wrap"); - if (label.classList.contains('open') != checked) { - label.click(); + if (extra) { + labelWrap.insertBefore(extra, labelWrap.lastElementChild); + } + + accordion.onChecked = function(checked){ + if (isOpen() != checked) { + labelWrap.click(); + } } + + var visibleCheckbox = document.createElement('INPUT'); + visibleCheckbox.type = 'checkbox'; + visibleCheckbox.checked = isOpen(); + visibleCheckbox.id = accordion.id + "-visible-checkbox"; + visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox"; + span.insertBefore(visibleCheckbox, span.firstChild); + + accordion.visibleCheckbox = visibleCheckbox; + accordion.onVisibleCheckboxChange = function(){ + if(linked && isOpen() != visibleCheckbox.checked) { + labelWrap.click(); + } + + gradioCheckbox.checked = visibleCheckbox.checked; + updateInput(gradioCheckbox); + }; + + visibleCheckbox.addEventListener('click', function(event){ + linked = false; + event.stopPropagation(); + }); + visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange); } onUiLoaded(function() { for (var accordion of gradioApp().querySelectorAll('.input-accordion')) { - var labelWrap = accordion.querySelector('.label-wrap'); - observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); - - var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); - if 
(extra) { - labelWrap.insertBefore(extra, labelWrap.lastElementChild); - } + setupAccordion(accordion); } }); diff --git a/style.css b/style.css index 115626cd9de..9a1181e8b33 100644 --- a/style.css +++ b/style.css @@ -204,6 +204,11 @@ div.block.gradio-accordion { padding: 8px 8px; } +input[type="checkbox"].input-accordion-checkbox{ + vertical-align: sub; + margin-right: 0.5em; +} + /* txt2img/img2img specific */ From 16ab17429016a1154b9aa83244cdbfc7ba463d72 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 5 Nov 2023 09:20:05 +0300 Subject: [PATCH 0230/1174] eslint --- javascript/inputAccordion.js | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/javascript/inputAccordion.js b/javascript/inputAccordion.js index 8fc01230a92..7570309aa73 100644 --- a/javascript/inputAccordion.js +++ b/javascript/inputAccordion.js @@ -4,20 +4,22 @@ function inputAccordionChecked(id, checked) { accordion.onVisibleCheckboxChange(); } -function setupAccordion(accordion){ +function setupAccordion(accordion) { var labelWrap = accordion.querySelector('.label-wrap'); var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); var span = labelWrap.querySelector('span'); var linked = true; - var isOpen = function(){ return labelWrap.classList.contains('open'); } + var isOpen = function() { + return labelWrap.classList.contains('open'); + }; var observerAccordionOpen = new MutationObserver(function(mutations) { mutations.forEach(function(mutationRecord) { accordion.classList.toggle('input-accordion-open', isOpen()); - if(linked){ + if (linked) { accordion.visibleCheckbox.checked = isOpen(); accordion.onVisibleCheckboxChange(); } @@ -29,22 +31,22 @@ function setupAccordion(accordion){ labelWrap.insertBefore(extra, labelWrap.lastElementChild); } - accordion.onChecked = function(checked){ + accordion.onChecked = function(checked) { if (isOpen() != checked) { labelWrap.click(); } - } + }; var visibleCheckbox = document.createElement('INPUT'); visibleCheckbox.type = 'checkbox'; visibleCheckbox.checked = isOpen(); - visibleCheckbox.id = accordion.id + "-visible-checkbox"; + visibleCheckbox.id = accordion.id + "-visible-checkbox"; visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox"; span.insertBefore(visibleCheckbox, span.firstChild); accordion.visibleCheckbox = visibleCheckbox; - accordion.onVisibleCheckboxChange = function(){ - if(linked && isOpen() != visibleCheckbox.checked) { + accordion.onVisibleCheckboxChange = function() { + if (linked && isOpen() != visibleCheckbox.checked) { labelWrap.click(); } @@ -52,7 +54,7 @@ function setupAccordion(accordion){ updateInput(gradioCheckbox); }; - visibleCheckbox.addEventListener('click', function(event){ + visibleCheckbox.addEventListener('click', function(event) { linked = false; event.stopPropagation(); }); From d9499f4301018ebd2977685d098381aa4111d2ae Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 5 Nov 2023 10:12:50 +0300 Subject: [PATCH 0231/1174] properly apply sort order for extra network cards when selected from dropdown allow selection of default sort order in settings remove 'Default' sort order, replace with 'Name' --- javascript/extraNetworks.js | 27 ++++++++++++++++----------- modules/shared_options.py | 2 ++ modules/ui_extra_networks.py | 6 ++++-- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/javascript/extraNetworks.js 
b/javascript/extraNetworks.js index ac26718f6f0..a4d1d9d9856 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -27,7 +27,6 @@ function setupExtraNetworksForTab(tabname) { var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); - sort.dataset.sortkey = 'sortDefault'; tabs.appendChild(searchDiv); tabs.appendChild(sort); tabs.appendChild(sortOrder); @@ -49,20 +48,23 @@ function setupExtraNetworksForTab(tabname) { elem.style.display = visible ? "" : "none"; }); + + applySort(); }; var applySort = function() { + var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); + var reverse = sortOrder.classList.contains("sortReverse"); - var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim(); - sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : ""; - var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : ""; - if (!sortKey || sortKeyStore == sort.dataset.sortkey) { + var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); + var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; + + if (sortKeyStore == sort.dataset.sortkey) { return; } - sort.dataset.sortkey = sortKeyStore; - var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); cards.forEach(function(card) { card.originalParentElement = card.parentElement; }); @@ -88,15 +90,13 @@ function setupExtraNetworksForTab(tabname) { }; search.addEventListener("input", applyFilter); - applyFilter(); - ["change", "blur", "click"].forEach(function(evt) { - sort.querySelector("input").addEventListener(evt, applySort); - }); sortOrder.addEventListener("click", function() { sortOrder.classList.toggle("sortReverse"); applySort(); }); + applyFilter(); + extraNetworksApplySort[tabname] = applySort; extraNetworksApplyFilter[tabname] = applyFilter; var showDirsUpdate = function() { @@ -113,7 +113,12 @@ function applyExtraNetworkFilter(tabname) { setTimeout(extraNetworksApplyFilter[tabname], 1); } +function applyExtraNetworkSort(tabname) { + setTimeout(extraNetworksApplySort[tabname], 1); +} + var extraNetworksApplyFilter = {}; +var extraNetworksApplySort = {}; var activePromptTextarea = {}; function setupExtraNetworks() { diff --git a/modules/shared_options.py b/modules/shared_options.py index 0a82216fff5..6543e440d47 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -234,6 +234,8 @@ "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), + "extra_networks_card_order_field": OptionInfo("Name", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), + "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra 
text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 59d6ecc612a..fc729917b77 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -381,8 +381,8 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) - dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") - button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False, tooltip="Invert sort order") + dropdown_sort = gr.Dropdown(choices=['Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") + button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) @@ -395,6 +395,8 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): for tab in related_tabs: tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) + dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") + def pages_html(): if not ui.pages_contents: return refresh() From ff1609f91ea0e9a90ba7b6ecc6d794c39c1f8c8f Mon Sep 17 00:00:00 2001 From: Ritesh Gangnani Date: Sun, 5 Nov 2023 19:13:49 +0530 Subject: [PATCH 0232/1174] Add SSD-1B as a supported model --- modules/sd_hijack.py | 11 +++++++++++ modules/sd_models.py | 8 ++++++-- modules/sd_models_types.py | 5 ++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 592f00551f1..d19f853ec2c 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,6 +180,17 @@ def apply_optimizations(self, option=None): except Exception as e: errors.display(e, "applying cross attention optimization") undo_optimizations() + + def conv_ssd(self, m): + delattr(m.model.diffusion_model.middle_block, '1') + delattr(m.model.diffusion_model.middle_block, '2') + for i in ['9','8','7','6','5','4']: + delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks,i) + delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks,i) + delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks,i) + delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i) + 
delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1') + delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1') def hijack(self, m): conditioner = getattr(m, 'conditioner', None) diff --git a/modules/sd_models.py b/modules/sd_models.py index 930d0bee5c8..ef96d29dbc4 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -346,10 +346,14 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sdxl = hasattr(model, 'conditioner') model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 - + model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() + if model.is_sdxl: sd_models_xl.extend_sdxl(model) - + + if model.is_ssd: + sd_hijack.model_hijack.conv_ssd(model) + model.load_state_dict(state_dict, strict=False) timer.record("apply weights to model") diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py index 5ffd2f4f9fd..1f28942a446 100644 --- a/modules/sd_models_types.py +++ b/modules/sd_models_types.py @@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion): """structure with additional information about the file with model's weights""" is_sdxl: bool - """True if the model's architecture is SDXL""" + """True if the model's architecture is SDXL or SSD""" + + is_ssd: bool + """True if the model is SSD""" is_sd2: bool """True if the model's architecture is SD 2.x""" From 44db35fb1ad5d07837e890a0fd3c00addfb0402c Mon Sep 17 00:00:00 2001 From: Ritesh Gangnani Date: Sun, 5 Nov 2023 19:15:38 +0530 Subject: [PATCH 0233/1174] Added memory clearance after deletion --- modules/sd_hijack.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d19f853ec2c..059ffe8f066 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -1,3 +1,5 @@ +import gc + import torch from torch.nn.functional import silu from types import MethodType @@ -190,7 +192,9 @@ def conv_ssd(self, m): delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks,i) delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i) delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1') - delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1') + delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1') + torch.cuda.empty_cache() + gc.collect() def hijack(self, m): conditioner = getattr(m, 'conditioner', None) From 44c5097375ae4cf40300c09473bb46cf6c5d6cb7 Mon Sep 17 00:00:00 2001 From: Ritesh Gangnani Date: Sun, 5 Nov 2023 20:31:57 +0530 Subject: [PATCH 0234/1174] Use devices.torch_gc() instead of empty_cache() --- modules/sd_hijack.py | 5 +---- modules/sd_models.py | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 059ffe8f066..0a7e5135d81 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -1,5 +1,3 @@ -import gc - import torch from torch.nn.functional import silu from types import MethodType @@ -193,8 +191,7 @@ def conv_ssd(self, m): delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i) delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1') delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1') - torch.cuda.empty_cache() - gc.collect() + devices.torch_gc() def hijack(self, m): conditioner = getattr(m, 
'conditioner', None) diff --git a/modules/sd_models.py b/modules/sd_models.py index ef96d29dbc4..2242c3637a5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -347,7 +347,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd1 = not model.is_sdxl and not model.is_sd2 model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() - if model.is_sdxl: sd_models_xl.extend_sdxl(model) From 4d4a9e733219f8c065a4ab6c5ab42836db7330fe Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 5 Nov 2023 19:19:55 +0300 Subject: [PATCH 0235/1174] added compact prompt option --- .../mobile/javascript/mobile.js | 2 + javascript/extraNetworks.js | 33 +++ modules/shared_items.py | 2 + modules/shared_options.py | 1 + modules/ui.py | 247 +++++++----------- modules/ui_common.py | 15 +- modules/ui_extra_networks.py | 16 +- modules/ui_extra_networks_checkpoints.py | 2 + modules/ui_toprow.py | 141 ++++++++++ style.css | 23 +- 10 files changed, 314 insertions(+), 168 deletions(-) create mode 100644 modules/ui_toprow.py diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js index 652f07ac7ec..bff1acedff3 100644 --- a/extensions-builtin/mobile/javascript/mobile.js +++ b/extensions-builtin/mobile/javascript/mobile.js @@ -12,6 +12,8 @@ function isMobile() { } function reportWindowSize() { + if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout + var currentlyMobile = isMobile(); if (currentlyMobile == isSetupForMobile) return; isSetupForMobile = currentlyMobile; diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index a4d1d9d9856..a1bf29a8c72 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -26,6 +26,8 @@ function setupExtraNetworksForTab(tabname) { var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); + var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); + var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); tabs.appendChild(searchDiv); tabs.appendChild(sort); @@ -109,6 +111,37 @@ function setupExtraNetworksForTab(tabname) { showDirsUpdate(); } +function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) { + if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout + + var promptContainer = gradioApp().getElementById(tabname + '_prompt_container'); + var prompt = gradioApp().getElementById(tabname + '_prompt_row'); + var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row'); + var elem = id ? 
gradioApp().getElementById(id) : null; + + if (showNegativePrompt && elem) { + elem.insertBefore(negPrompt, elem.firstChild); + } else { + promptContainer.insertBefore(negPrompt, promptContainer.firstChild); + } + + if (showPrompt && elem) { + elem.insertBefore(prompt, elem.firstChild); + } else { + promptContainer.insertBefore(prompt, promptContainer.firstChild); + } +} + + +function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) + extraNetworksMovePromptToTab(tabname, '', false, false); +} + +function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab + extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); + +} + function applyExtraNetworkFilter(tabname) { setTimeout(extraNetworksApplyFilter[tabname], 1); } diff --git a/modules/shared_items.py b/modules/shared_items.py index b1459f8c4e4..5024b4268d0 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -67,6 +67,8 @@ def reload_hypernetworks(): ui_reorder_categories_builtin_items = [ + "prompt", + "image", "inpaint", "sampler", "accordions", diff --git a/modules/shared_options.py b/modules/shared_options.py index 6543e440d47..4e3d7541699 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -272,6 +272,7 @@ "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), + "compact_prompt_box": OptionInfo(True, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(), })) diff --git a/modules/ui.py b/modules/ui.py index bcf39199715..2454eb36bbf 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,7 +12,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import gradio_extensons # noqa: F401 -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow from modules.paths import script_path from modules.ui_common import create_refresh_button @@ -25,7 +25,6 @@ import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared -import modules.images from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.generation_parameters_copypaste import image_from_url_text @@ -177,79 +176,6 @@ def update_negative_prompt_token_counter(text, steps): return update_token_counter(text, steps, is_positive=False) -class Toprow: - """Creates a top row UI with prompts, generate button, styles, extra 
little buttons for things, and enables some functionality related to their operation""" - - def __init__(self, is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - self.id_part = id_part - - with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): - with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - - self.button_interrogate = None - self.button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_classes="interrogate-col"): - self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): - with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"): - self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") - self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") - self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - self.skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - self.interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(elem_id=f"{id_part}_tools"): - self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.") - self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt", tooltip="Clear prompt") - self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{id_part}_style_apply", tooltip="Apply all selected styles to prompts.") - self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False, tooltip="Restore progress") - - self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) - self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"]) - self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") - - self.clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[self.prompt, self.negative_prompt], - outputs=[self.prompt, self.negative_prompt], - ) - - self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt) - self.ui_styles.setup_apply_button(self.apply_styles) - - self.prompt_img.change( - fn=modules.images.image_data, - inputs=[self.prompt_img], - outputs=[self.prompt, self.prompt_img], - show_progress=False, - ) - def setup_progressbar(*args, **kwargs): pass @@ 
-288,8 +214,8 @@ def apply_setting(key, value): return getattr(opts, key) -def create_output_panel(tabname, outdir): - return ui_common.create_output_panel(tabname, outdir) +def create_output_panel(tabname, outdir, toprow=None): + return ui_common.create_output_panel(tabname, outdir, toprow) def create_sampler_and_steps_selection(choices, tabname): @@ -336,7 +262,7 @@ def create_ui(): scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - toprow = Toprow(is_img2img=False) + toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box) dummy_component = gr.Label(visible=False) @@ -348,6 +274,9 @@ def create_ui(): scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): + if category == "prompt": + toprow.create_inline_toprow_prompts() + if category == "sampler": steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") @@ -442,7 +371,7 @@ def create_ui(): show_progress=False, ) - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), @@ -554,7 +483,7 @@ def create_ui(): scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - toprow = Toprow(is_img2img=True) + toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box) extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs") extra_tabs.__enter__() @@ -577,85 +506,89 @@ def add_copy_image_controls(tab_name, elem): button = gr.Button(title) copy_image_buttons.append((button, name, elem)) - with gr.Tabs(elem_id="mode_img2img"): - img2img_selected_tab = gr.State(0) - - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) - add_copy_image_controls('img2img', init_img) - - with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) - add_copy_image_controls('sketch', sketch) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) - add_copy_image_controls('inpaint', init_img_with_mask) - - with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, 
brush_color=opts.img2img_inpaint_sketch_default_brush_color) - inpaint_color_sketch_orig = gr.State(None) - add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) - - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - - with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") - - with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML( - "<p style=\"margin-bottom:0.75em\">Process images in a directory on the same machine where the server is running." + - "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + - f"<br>Add inpaint batch mask directory to enable inpaint batch processing." - f"{hidden}</p>
" - ) - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") - with gr.Accordion("PNG info", open=False): - img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") - img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") - img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") - - img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] - - for i, tab in enumerate(img2img_tabs): - tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) - - def copy_image(img): - if isinstance(img, dict) and 'image' in img: - return img['image'] - - return img - - for button, name, elem in copy_image_buttons: - button.click( - fn=copy_image, - inputs=[elem], - outputs=[copy_image_destinations[name]], - ) - button.click( - fn=lambda: None, - _js=f"switch_to_{name.replace(' ', '_')}", - inputs=[], - outputs=[], - ) - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - scripts.scripts_img2img.prepare_ui() for category in ordered_ui_categories(): + if category == "prompt": + toprow.create_inline_toprow_prompts() + + if category == "image": + with gr.Tabs(elem_id="mode_img2img"): + img2img_selected_tab = gr.State(0) + + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) + add_copy_image_controls('img2img', init_img) + + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) + add_copy_image_controls('sketch', sketch) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) + add_copy_image_controls('inpaint', init_img_with_mask) + + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", 
interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) + inpaint_color_sketch_orig = gr.State(None) + add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) + + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) + + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") + + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: + hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML( + "<p style=\"margin-bottom:0.75em\">Process images in a directory on the same machine where the server is running." + + "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + + f"<br>Add inpaint batch mask directory to enable inpaint batch processing." + f"{hidden}</p>
" + ) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + with gr.Accordion("PNG info", open=False): + img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") + img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") + + img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] + + for i, tab in enumerate(img2img_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) + + def copy_image(img): + if isinstance(img, dict) and 'image' in img: + return img['image'] + + return img + + for button, name, elem in copy_image_buttons: + button.click( + fn=copy_image, + inputs=[elem], + outputs=[copy_image_destinations[name]], + ) + button.click( + fn=lambda: None, + _js=f"switch_to_{name.replace(' ', '_')}", + inputs=[], + outputs=[], + ) + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + if category == "sampler": steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") @@ -769,7 +702,7 @@ def select_img2img_tab(tab): if category not in {"accordions"}: scripts.scripts_img2img.setup_ui_for_section(category) - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), diff --git a/modules/ui_common.py b/modules/ui_common.py index 84a7d7f2756..032ec4af762 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -104,7 +104,7 @@ def __init__(self, d=None): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") -def create_output_panel(tabname, outdir): +def create_output_panel(tabname, outdir, toprow=None): def open_folder(f): if not os.path.exists(f): @@ -130,12 +130,15 @@ def open_folder(f): else: sp.Popen(["xdg-open", path]) - with gr.Column(variant='panel', elem_id=f"{tabname}_results"): - with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + with gr.Column(elem_id=f"{tabname}_results"): + if toprow: + toprow.create_inline_toprow_image() - generation_info = None - with gr.Column(): + with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): 
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + + generation_info = None with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.") diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index fc729917b77..7907cd63f9c 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -103,6 +103,7 @@ def __init__(self, title): self.name = title.lower() self.id_page = self.name.replace(" ", "_") self.card_page = shared.html("extra-networks-card.html") + self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} self.items = {} @@ -367,7 +368,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs = [] for page in ui.stored_extra_pages: - with gr.Tab(page.title, id=page.id_page) as tab: + with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) @@ -389,11 +390,18 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs] + for tab in unrelated_tabs: - tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) + tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False) + + for page, tab in zip(ui.stored_extra_pages, related_tabs): + allow_prompt = "true" if page.allow_prompt else "false" + allow_negative_prompt = "true" if page.allow_negative_prompt else "false" + + jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');' - for tab in related_tabs: - tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) + tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False) dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index ca6c26076f9..2fc0ed43d03 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -10,6 +10,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def __init__(self): super().__init__('Checkpoints') + self.allow_prompt = False + def refresh(self): shared.refresh_checkpoints() diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py new file mode 100644 index 00000000000..985b5a2dd9c --- /dev/null +++ 
b/modules/ui_toprow.py @@ -0,0 +1,141 @@ +import gradio as gr + +from modules import shared, ui_prompt_styles +import modules.images + +from modules.ui_components import ToolButton + + +class Toprow: + """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation""" + + prompt = None + prompt_img = None + negative_prompt = None + + button_interrogate = None + button_deepbooru = None + + interrupt = None + skip = None + submit = None + + paste = None + clear_prompt_button = None + apply_styles = None + restore_progress_button = None + + token_counter = None + token_button = None + negative_token_counter = None + negative_token_button = None + + ui_styles = None + + submit_box = None + + def __init__(self, is_img2img, is_compact=False): + id_part = "img2img" if is_img2img else "txt2img" + self.id_part = id_part + self.is_img2img = is_img2img + self.is_compact = is_compact + + if not is_compact: + with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): + self.create_classic_toprow() + else: + self.create_submit_box() + + def create_classic_toprow(self): + self.create_prompts() + + with gr.Column(scale=1, elem_id=f"{self.id_part}_actions_column"): + self.create_submit_box() + + self.create_tools_row() + + self.create_styles_ui() + + def create_inline_toprow_prompts(self): + if not self.is_compact: + return + + self.create_prompts() + + with gr.Row(elem_classes=["toprow-compact-stylerow"]): + with gr.Column(elem_classes=["toprow-compact-tools"]): + self.create_tools_row() + with gr.Column(): + self.create_styles_ui() + + def create_inline_toprow_image(self): + if not self.is_compact: + return + + self.submit_box.render() + + def create_prompts(self): + with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): + with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) + + with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + + self.prompt_img.change( + fn=modules.images.image_data, + inputs=[self.prompt_img], + outputs=[self.prompt, self.prompt_img], + show_progress=False, + ) + + def create_submit_box(self): + with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box: + self.submit_box = submit_box + + self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt") + self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip") + self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary') + + self.skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + self.interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def 
create_tools_row(self): + with gr.Row(elem_id=f"{self.id_part}_tools"): + from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol + + self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.") + self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{self.id_part}_clear_prompt", tooltip="Clear prompt") + self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{self.id_part}_style_apply", tooltip="Apply all selected styles to prompts.") + + if self.is_img2img: + self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id="interrogate") + self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id="deepbooru") + + self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress") + + self.token_counter = gr.HTML(value="0/75", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"]) + self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button") + self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"]) + self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button") + + self.clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[self.prompt, self.negative_prompt], + outputs=[self.prompt, self.negative_prompt], + ) + + def create_styles_ui(self): + self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt) + self.ui_styles.setup_apply_button(self.apply_styles) diff --git a/style.css b/style.css index 9a1181e8b33..731620226e9 100644 --- a/style.css +++ b/style.css @@ -296,6 +296,13 @@ input[type="checkbox"].input-accordion-checkbox{ min-height: 4.5em; } +#txt2img_generate, #img2img_generate { + min-height: 4.5em; +} +.generate-box-compact #txt2img_generate, .generate-box-compact #img2img_generate { + min-height: 3em; +} + @media screen and (min-width: 2500px) { #txt2img_gallery, #img2img_gallery { min-height: 768px; @@ -403,6 +410,15 @@ div#extras_scale_to_tab div.form{ min-width: 0.5em; } +div.toprow-compact-stylerow{ + margin: 0.5em 0; +} + +div.toprow-compact-tools{ + min-width: fit-content !important; + max-width: fit-content; +} + /* settings */ #quicksettings { align-items: end; @@ -525,7 +541,8 @@ table.popup-table .link{ height: 20px; background: #b4c0cc; border-radius: 3px !important; - top: -20px; + top: -14px; + left: 0px; width: 100%; } @@ -823,6 +840,10 @@ footer { /* extra networks UI */ +.extra-page .prompt{ + margin: 0 0 0.5em 0; +} + .extra-network-cards{ height: calc(100vh - 24rem); overflow: clip scroll; From c3699d4fd185d5a7285c5519f9bb4b6fec236d9f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 5 Nov 2023 19:23:48 +0300 Subject: [PATCH 0236/1174] compact prompt option disabled by default --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 4e3d7541699..a9964fcbbfe 100644 --- a/modules/shared_options.py +++ 
b/modules/shared_options.py @@ -272,7 +272,7 @@ "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), - "compact_prompt_box": OptionInfo(True, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(), + "compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(), })) From 80d639a440929e9effe4620ce74333de283e7efc Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 5 Nov 2023 19:32:21 +0300 Subject: [PATCH 0237/1174] linter --- modules/sd_hijack.py | 2 +- modules/sd_models.py | 4 ++-- modules/sd_models_types.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 4fff418d757..c6d17764e41 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -183,7 +183,7 @@ def apply_optimizations(self, option=None): except Exception as e: errors.display(e, "applying cross attention optimization") undo_optimizations() - + def conv_ssd(self, m): delattr(m.model.diffusion_model.middle_block, '1') delattr(m.model.diffusion_model.middle_block, '2') diff --git a/modules/sd_models.py b/modules/sd_models.py index d76dc5803ae..1036a3b1316 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -355,10 +355,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys() if model.is_sdxl: sd_models_xl.extend_sdxl(model) - + if model.is_ssd: sd_hijack.model_hijack.conv_ssd(model) - + if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model checkpoints_loaded[checkpoint_info] = state_dict.copy() diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py index 1f28942a446..f911fbb68db 100644 --- a/modules/sd_models_types.py +++ b/modules/sd_models_types.py @@ -23,7 +23,7 @@ class WebuiSdModel(LatentDiffusion): is_sdxl: bool """True if the model's architecture is SDXL or SSD""" - + is_ssd: bool """True if the model is SSD""" From 6ad666e4794a57dd65790dd6a259d5d4330d45ed Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 5 Nov 2023 19:46:20 +0300 Subject: [PATCH 0238/1174] more changes for #13865: fix formatting, rename the function, add comment and add a readme entry --- README.md | 1 + modules/sd_hijack.py | 24 +++++++++++++----------- modules/sd_models.py | 2 +- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index c7a4e363fed..25ba070e7e8 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ A browser interface based on Gradio library for Stable Diffusion. - Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64 - Now with a license! 
- Reorder elements in the UI from settings screen +- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for: diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c6d17764e41..fba23c38b2e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -184,17 +184,19 @@ def apply_optimizations(self, option=None): errors.display(e, "applying cross attention optimization") undo_optimizations() - def conv_ssd(self, m): - delattr(m.model.diffusion_model.middle_block, '1') - delattr(m.model.diffusion_model.middle_block, '2') - for i in ['9','8','7','6','5','4']: - delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks,i) - delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks,i) - delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks,i) - delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i) - delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1') - delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1') - devices.torch_gc() + def convert_sdxl_to_ssd(self, m): + """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)""" + + delattr(m.model.diffusion_model.middle_block, '1') + delattr(m.model.diffusion_model.middle_block, '2') + for i in ['9', '8', '7', '6', '5', '4']: + delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i) + delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i) + delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1') + delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1') + devices.torch_gc() def hijack(self, m): conditioner = getattr(m, 'conditioner', None) diff --git a/modules/sd_models.py b/modules/sd_models.py index 1036a3b1316..841402e8629 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -357,7 +357,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_models_xl.extend_sdxl(model) if model.is_ssd: - sd_hijack.model_hijack.conv_ssd(model) + sd_hijack.model_hijack.convert_sdxl_to_ssd(model) if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model From 656437e0a50212778746c67785d23b0ea14a8837 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 6 Nov 2023 10:32:21 +0300 Subject: [PATCH 0239/1174] fix img2img_tabs error --- modules/ui.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 2454eb36bbf..accdb457256 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -689,19 +689,19 @@ def copy_image(img): with gr.Column(scale=4): inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - def select_img2img_tab(tab): - return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), - - for i, elem in enumerate(img2img_tabs): - elem.select( - fn=lambda tab=i: select_img2img_tab(tab), - inputs=[], - outputs=[inpaint_controls, mask_alpha], - ) - if category not in {"accordions"}: 
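# (annotation, not part of the patch) "accordions" is excluded because that
# category already ran setup_ui_for_section inside its own branch; the
# tab-select wiring that follows was moved out of the per-category loop so it
# runs once, after the controls it references (img2img_tabs, inpaint_controls,
# mask_alpha) have all been created.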
scripts.scripts_img2img.setup_ui_for_section(category) + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate(img2img_tabs): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) img2img_args = dict( From 9c1c0da026cb7ef091a0f3fa24b14ae8634f6de5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 6 Nov 2023 11:17:36 +0300 Subject: [PATCH 0240/1174] fix exception related to the pix2pix --- modules/sd_hijack.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fba23c38b2e..0157e19f003 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -256,8 +256,12 @@ def flatten(el): self.layers = flatten(m) + import modules.models.diffusion.ddpm_edit + if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion): sd_unet.original_forward = ldm_original_forward + elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion): + sd_unet.original_forward = ldm_original_forward elif isinstance(m, sgm.models.diffusion.DiffusionEngine): sd_unet.original_forward = sgm_original_forward else: From 9ba991cad8a15a99f71f5b2ec5feff7dd9d270d7 Mon Sep 17 00:00:00 2001 From: GerryDE Date: Tue, 7 Nov 2023 03:09:08 +0100 Subject: [PATCH 0241/1174] Add option to set notification sound volume --- javascript/notification.js | 6 +++++- modules/shared_options.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/javascript/notification.js b/javascript/notification.js index 6d79956125c..3ee972ae166 100644 --- a/javascript/notification.js +++ b/javascript/notification.js @@ -26,7 +26,11 @@ onAfterUiUpdate(function() { lastHeadImg = headImg; // play notification sound if available - gradioApp().querySelector('#audio_notification audio')?.play(); + const notificationAudio = gradioApp().querySelector('#audio_notification audio'); + if (notificationAudio) { + notificationAudio.volume = opts.notification_volume / 100.0 || 1.0; + notificationAudio.play(); + } if (document.hasFocus()) return; diff --git a/modules/shared_options.py b/modules/shared_options.py index a9964fcbbfe..d40db5306ee 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -64,6 +64,7 @@ "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."), "notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(), + "notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"), })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { From 5e80d9ee99c5899e5e2b130408ffb65a0585a62a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 7 Nov 2023 11:33:16 +0300 Subject: [PATCH 0242/1174] fix pix2pix producing bad results --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 70ad1ebed0a..b0e240a4640 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -296,7 +296,7 @@ def depth2img_image_conditioning(self, 
source_image): return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) + conditioning_image = shared.sd_model.encode_first_stage(source_image).mode() return conditioning_image From 6b9795849d497b41514aa9462690cf7c2802e4f6 Mon Sep 17 00:00:00 2001 From: hako-mikan <122196982+hako-mikan@users.noreply.github.com> Date: Thu, 9 Nov 2023 20:23:37 +0900 Subject: [PATCH 0243/1174] Fix model switch bug --- extensions-builtin/Lora/networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 96f935b236f..a21ea0fa356 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -418,7 +418,7 @@ def network_forward(module, input, original_forward): def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): self.network_current_names = () self.network_weights_backup = None - + self.network_bias_backup = None def network_Linear_forward(self, input): if shared.opts.lora_functional: From a625a7bb817cbf6a97d2030dc3a8015a046bd388 Mon Sep 17 00:00:00 2001 From: Emily Zeng Date: Thu, 9 Nov 2023 13:15:06 -0500 Subject: [PATCH 0244/1174] moved nested with to single line to remove extra tabs --- modules/ui.py | 591 +++++++++++++++++++++++++------------------------- 1 file changed, 295 insertions(+), 296 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 4a3e60d13c9..0faccbd3488 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -270,88 +270,87 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Accordion("Open for Settings", open=False) if shared.opts.img2img_settings_accordion else gr.Group(): - with gr.Column(variant='compact', elem_id="txt2img_settings"): - scripts.scripts_txt2img.prepare_ui() + with gr.Accordion("Open for Settings", open=False), gr.Column(variant='compact', elem_id="txt2img_settings") if shared.opts.img2img_settings_accordion else gr.Column(variant='compact', elem_id="txt2img_settings"): + scripts.scripts_txt2img.prepare_ui() - for category in ordered_ui_categories(): - if category == "prompt": - toprow.create_inline_toprow_prompts() + for category in ordered_ui_categories(): + if category == "prompt": + toprow.create_inline_toprow_prompts() - if category == "sampler": - steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") + if category == "sampler": + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): - res_switch_btn = ToolButton(value=switch_values_symbol, 
elem_id="txt2img_res_switch_btn", tooltip="Switch width/height") + with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height") - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - elif category == "cfg": - with gr.Row(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + elif category == "cfg": + with gr.Row(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - elif category == "checkboxes": - with FormRow(elem_classes="checkboxes-row", variant="compact"): - pass + elif category == "checkboxes": + with FormRow(elem_classes="checkboxes-row", variant="compact"): + pass - elif category == "accordions": - with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"): - with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr: - with enable_hr.extra(): - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0) + elif category == "accordions": + with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"): + with InputAccordion(False, label="Hires. 
fix", elem_id="txt2img_hr") as enable_hr: + with enable_hr.extra(): + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0) - with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: + with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: - hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") - create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") + hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") + create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") - hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") + hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), 
value="Use same sampler") - with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: - with gr.Column(scale=80): - with gr.Row(): - hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"]) - with gr.Column(scale=80): - with gr.Row(): - hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) + with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: + with gr.Column(scale=80): + with gr.Row(): + hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"]) + with gr.Column(scale=80): + with gr.Row(): + hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) - scripts.scripts_txt2img.setup_ui_for_section(category) + scripts.scripts_txt2img.setup_ui_for_section(category) - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - elif category == "override_settings": - with FormRow(elem_id="txt2img_override_settings_row") as row: - override_settings = create_override_settings_dropdown('txt2img', row) + elif category == "override_settings": + with FormRow(elem_id="txt2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('txt2img', row) - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = scripts.scripts_txt2img.setup_ui() + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = scripts.scripts_txt2img.setup_ui() - if category not in {"accordions"}: - scripts.scripts_txt2img.setup_ui_for_section(category) + if category not in {"accordions"}: + scripts.scripts_txt2img.setup_ui_for_section(category) hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] @@ -490,258 +489,258 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Accordion("Open for Settings", open=False) if shared.opts.img2img_settings_accordion else gr.Group(): - with gr.Column(variant='compact', elem_id="img2img_settings"): - copy_image_buttons = [] - copy_image_destinations = {} - - def 
add_copy_image_controls(tab_name, elem): - with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): - gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}") - - for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): - if name == tab_name: - gr.Button(title, interactive=False) - copy_image_destinations[name] = elem - continue - - button = gr.Button(title) - copy_image_buttons.append((button, name, elem)) - - scripts.scripts_img2img.prepare_ui() - - for category in ordered_ui_categories(): - if category == "prompt": - toprow.create_inline_toprow_prompts() - - if category == "image": - with gr.Tabs(elem_id="mode_img2img"): - img2img_selected_tab = gr.State(0) - - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) - add_copy_image_controls('img2img', init_img) - - with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) - add_copy_image_controls('sketch', sketch) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) - add_copy_image_controls('inpaint', init_img_with_mask) - - with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) - inpaint_color_sketch_orig = gr.State(None) - add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) - - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - - with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") - - with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML( - "<p>Process images in a directory on the same machine where the server is running." + - "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + - f"<br>Add inpaint batch mask directory to enable inpaint batch processing." - f"{hidden}</p>
" - ) - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") - with gr.Accordion("PNG info", open=False): - img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") - img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") - img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") - - img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] - - for i, tab in enumerate(img2img_tabs): - tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) - - def copy_image(img): - if isinstance(img, dict) and 'image' in img: - return img['image'] - - return img - - for button, name, elem in copy_image_buttons: - button.click( - fn=copy_image, - inputs=[elem], - outputs=[copy_image_destinations[name]], - ) - button.click( - fn=lambda: None, - _js=f"switch_to_{name.replace(' ', '_')}", - inputs=[], - outputs=[], + with gr.Accordion("Open for Settings", open=False), gr.Column(variant='compact', elem_id="img2img_settings") if shared.opts.img2img_settings_accordion else gr.Column(variant='compact', elem_id="img2img_settings"): + copy_image_buttons = [] + copy_image_destinations = {} + + def add_copy_image_controls(tab_name, elem): + with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"): + gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}") + + for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']): + if name == tab_name: + gr.Button(title, interactive=False) + copy_image_destinations[name] = elem + continue + + button = gr.Button(title) + copy_image_buttons.append((button, name, elem)) + + scripts.scripts_img2img.prepare_ui() + + for category in ordered_ui_categories(): + if category == "prompt": + toprow.create_inline_toprow_prompts() + + if category == "image": + with gr.Tabs(elem_id="mode_img2img"): + img2img_selected_tab = gr.State(0) + + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) + add_copy_image_controls('img2img', init_img) + + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) + add_copy_image_controls('sketch', sketch) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + 
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) + add_copy_image_controls('inpaint', init_img_with_mask) + + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) + inpaint_color_sketch_orig = gr.State(None) + add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) + + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) + + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") + + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: + hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML( + "<p>Process images in a directory on the same machine where the server is running." + + "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + + f"<br>Add inpaint batch mask directory to enable inpaint batch processing." + f"{hidden}</p>
" ) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + with gr.Accordion("PNG info", open=False): + img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") + img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") + + img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] + + for i, tab in enumerate(img2img_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) + + def copy_image(img): + if isinstance(img, dict) and 'image' in img: + return img['image'] + + return img + + for button, name, elem in copy_image_buttons: + button.click( + fn=copy_image, + inputs=[elem], + outputs=[copy_image_destinations[name]], + ) + button.click( + fn=lambda: None, + _js=f"switch_to_{name.replace(' ', '_')}", + inputs=[], + outputs=[], + ) + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + if category == "sampler": + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + selected_scale_tab = gr.State(value=0) + + with gr.Tabs(): + with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height") + detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img") + + with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: + scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") + + with FormRow(): + scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") + gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") + button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") + + on_change_args = dict( + fn=resize_from_to_html, + _js="currentImg2imgSourceResolution", + inputs=[dummy_component, dummy_component, scale_by], + 
outputs=scale_by_html, + show_progress=False, + ) - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + scale_by.release(**on_change_args) + button_update_resize_to.click(**on_change_args) + + # the code below is meant to update the resolution label after the image in the image selection UI has changed. + # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. + # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. + for component in [init_img, sketch]: + component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + scripts.scripts_img2img.prepare_ui() + + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + selected_scale_tab = gr.State(value=0) + + with gr.Tabs(): + with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height") + detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img") + + with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: + scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") + + with FormRow(): + scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") + gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") + button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") + + on_change_args = dict( + fn=resize_from_to_html, + _js="currentImg2imgSourceResolution", + inputs=[dummy_component, dummy_component, scale_by], + outputs=scale_by_html, + show_progress=False, + ) - if category == "sampler": - steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") + scale_by.release(**on_change_args) + button_update_resize_to.click(**on_change_args) - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - selected_scale_tab = gr.State(value=0) - - with gr.Tabs(): - with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = 
gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height") - detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img") - - with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: - scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") - - with FormRow(): - scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") - gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") - button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") - - on_change_args = dict( - fn=resize_from_to_html, - _js="currentImg2imgSourceResolution", - inputs=[dummy_component, dummy_component, scale_by], - outputs=scale_by_html, - show_progress=False, - ) - - scale_by.release(**on_change_args) - button_update_resize_to.click(**on_change_args) - - # the code below is meant to update the resolution label after the image in the image selection UI has changed. - # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. - # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. - for component in [init_img, sketch]: - component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) + # the code below is meant to update the resolution label after the image in the image selection UI has changed. + # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. + # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. 
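# A standalone sketch of the workaround described in the comment above and applied in
# the `for component in [init_img, sketch]` hook that follows: attach the label-refresh
# handler only to inputs whose .change fires once per upload, and leave sketch/inpaint
# inputs unhooked so per-brush-stroke events cannot flood the server with requests.
# Plain Gradio 3.x API; the component names below are illustrative, not webui's.
import gradio as gr

def describe(img):
    # lightweight callback that recomputes a label from the current image
    return f"{img.width} x {img.height}" if img is not None else "no image"

with gr.Blocks() as demo:
    plain = gr.Image(label="Plain upload", type="pil")  # .change fires once per new image
    sketch = gr.Image(label="Sketch", type="pil", tool="color-sketch")  # .change fires per edit event
    size_label = gr.HTML("no image")

    # hook only the plain input; hooking `sketch` would issue one request per brush stroke
    plain.change(describe, inputs=[plain], outputs=[size_label], show_progress=False)

demo.launch()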
+ for component in [init_img, sketch]: + component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) + tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) - scripts.scripts_img2img.prepare_ui() + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") + elif category == "denoising": + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - selected_scale_tab = gr.State(value=0) - - with gr.Tabs(): - with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): - res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height") - detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img") - - with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: - scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") - - with FormRow(): - scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") - gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") - button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") - - on_change_args = dict( - fn=resize_from_to_html, - _js="currentImg2imgSourceResolution", - inputs=[dummy_component, dummy_component, scale_by], - outputs=scale_by_html, - show_progress=False, - ) - - scale_by.release(**on_change_args) - button_update_resize_to.click(**on_change_args) - - # the code below is meant to update the resolution label after the image in the image selection UI has changed. - # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. - # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. 
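# The selected_scale_tab wiring added above (a gr.State that each tab's select event
# overwrites with its own index, so later handlers know which resize mode is active)
# reduces to the following standalone sketch. Gradio 3.x API; names are illustrative.
import gradio as gr

with gr.Blocks() as demo:
    selected_tab = gr.State(value=0)  # index of the currently selected tab

    with gr.Tabs():
        with gr.Tab("Resize to") as tab_to:
            gr.Markdown("width/height sliders would go here")
        with gr.Tab("Resize by") as tab_by:
            gr.Markdown("a scale slider would go here")

    # each tab writes its own index into the shared state when selected
    tab_to.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
    tab_by.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])

    report = gr.Textbox(label="Active mode")
    # downstream handlers read the state like any other input
    gr.Button("Report").click(lambda t: "resize-to" if t == 0 else "resize-by", inputs=[selected_tab], outputs=[report])

demo.launch()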
- for component in [init_img, sketch]: - component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - - tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) - tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "denoising": - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "cfg": - with gr.Row(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False) + elif category == "cfg": + with gr.Row(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False) - elif category == "checkboxes": - with FormRow(elem_classes="checkboxes-row", variant="compact"): - pass + elif category == "checkboxes": + with FormRow(elem_classes="checkboxes-row", variant="compact"): + pass - elif category == "accordions": - with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"): - scripts.scripts_img2img.setup_ui_for_section(category) + elif category == "accordions": + with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"): + scripts.scripts_img2img.setup_ui_for_section(category) - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - elif category == "override_settings": - with FormRow(elem_id="img2img_override_settings_row") as row: - override_settings = create_override_settings_dropdown('img2img', row) + elif category == "override_settings": + with FormRow(elem_id="img2img_override_settings_row") as row: + override_settings = create_override_settings_dropdown('img2img', row) - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = scripts.scripts_img2img.setup_ui() + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = scripts.scripts_img2img.setup_ui() - elif category == "inpaint": - with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = 
gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + elif category == "inpaint": + with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - with FormRow(): - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + with FormRow(): + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - if category not in {"accordions"}: - scripts.scripts_img2img.setup_ui_for_section(category) + if category not in {"accordions"}: + scripts.scripts_img2img.setup_ui_for_section(category) + def select_img2img_tab(tab): return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), From 9aa4d098f07655d99cd16e8e9c984d043dbf9006 Mon Sep 17 00:00:00 2001 From: Emily Zeng Date: Thu, 9 Nov 2023 13:25:24 -0500 Subject: [PATCH 0245/1174] removed changes that weren't merged properly --- modules/ui.py | 51 +-------------------------------------------------- 1 file changed, 1 insertion(+), 50 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 0faccbd3488..3eec7839f76 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -592,55 +592,6 @@ def copy_image(img): if category == "sampler": steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - selected_scale_tab = gr.State(value=0) - - with gr.Tabs(): - with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): - res_switch_btn = 
ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height") - detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img") - - with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: - scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") - - with FormRow(): - scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview") - gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider") - button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to") - - on_change_args = dict( - fn=resize_from_to_html, - _js="currentImg2imgSourceResolution", - inputs=[dummy_component, dummy_component, scale_by], - outputs=scale_by_html, - show_progress=False, - ) - - scale_by.release(**on_change_args) - button_update_resize_to.click(**on_change_args) - - # the code below is meant to update the resolution label after the image in the image selection UI has changed. - # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. - # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. - for component in [init_img, sketch]: - component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - scripts.scripts_img2img.prepare_ui() - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") @@ -740,7 +691,7 @@ def copy_image(img): if category not in {"accordions"}: scripts.scripts_img2img.setup_ui_for_section(category) - + def select_img2img_tab(tab): return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), From ff2952f10551aab2000002079d5f862af979e964 Mon Sep 17 00:00:00 2001 From: Emily Zeng Date: Thu, 9 Nov 2023 13:35:52 -0500 Subject: [PATCH 0246/1174] multiline with statement for readability --- modules/ui.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/ui.py index 3eec7839f76..bf06776e935 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -270,7 +270,9 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Accordion("Open for Settings", open=False), gr.Column(variant='compact', elem_id="txt2img_settings") if shared.opts.img2img_settings_accordion else gr.Column(variant='compact', elem_id="txt2img_settings"): + with gr.Accordion("Open for Settings", open=False), gr.Column(variant='compact', elem_id="txt2img_settings") \ + if shared.opts.img2img_settings_accordion else gr.Column(variant='compact', elem_id="txt2img_settings"): + scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): @@ -489,7 +491,9 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab,
ResizeHandleRow(equal_height=False): - with gr.Accordion("Open for Settings", open=False), gr.Column(variant='compact', elem_id="img2img_settings") if shared.opts.img2img_settings_accordion else gr.Column(variant='compact', elem_id="img2img_settings"): + with gr.Accordion("Open for Settings", open=False), gr.Column(variant='compact', elem_id="img2img_settings") \ + if shared.opts.img2img_settings_accordion else gr.Column(variant='compact', elem_id="img2img_settings"): + copy_image_buttons = [] copy_image_destinations = {} From 98fc525a2c3ebc548ba774e564d99bea2db2f191 Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Fri, 10 Nov 2023 14:37:30 +0800 Subject: [PATCH 0247/1174] Update README.md Modify the stablediffusion dependency address From 6d77a6e1c6b27ae82b2186cfc36cc4ad2a5e9ecf Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Fri, 10 Nov 2023 14:40:39 +0800 Subject: [PATCH 0248/1174] Update README.md Modify the stablediffusion dependency address --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4e08344008c..1c97ecbb0dc 100644 --- a/README.md +++ b/README.md @@ -146,7 +146,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h ## Credits Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file. -- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers +- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers - k-diffusion - https://github.com/crowsonkb/k-diffusion.git - GFPGAN - https://github.com/TencentARC/GFPGAN.git - CodeFormer - https://github.com/sczhou/CodeFormer From 66767e3876dde8d0ef27ce00254cd6b75332f036 Mon Sep 17 00:00:00 2001 From: "Alessandro de Oliveira Faria (A.K.A. 
CABELO)" Date: Fri, 10 Nov 2023 03:45:44 -0300 Subject: [PATCH 0249/1174] - opensuse compatibility --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 4e08344008c..89e54a6132e 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,8 @@ Alternatively, use online services (like Google Colab): sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0 # Red Hat-based: sudo dnf install wget git python3 +# openSUSE-based: +sudo zypper install wget git python3 # Arch-based: sudo pacman -S wget git python3 ``` From 7ff54005fee46ce188544db75c27de61ae279001 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Thu, 9 Nov 2023 23:47:53 -0700 Subject: [PATCH 0250/1174] Enable prompt hotkeys in style editor --- modules/ui_prompt_styles.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py index 85eb3a6417e..d6f8d4c76e8 100644 --- a/modules/ui_prompt_styles.py +++ b/modules/ui_prompt_styles.py @@ -64,10 +64,10 @@ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.") with gr.Row(): - self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3) + self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"]) with gr.Row(): - self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3) + self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"]) with gr.Row(): self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False) From 6a86b3ad9bc7bb9a58dc4228ecf93a3a511ed122 Mon Sep 17 00:00:00 2001 From: "Alessandro de Oliveira Faria (A.K.A. 
CABELO)" Date: Fri, 10 Nov 2023 14:15:34 -0300 Subject: [PATCH 0251/1174] Compatibility with Debian 11, Fedora 34+ and openSUSE 15.4+ --- README.md | 4 ++-- webui.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 89e54a6132e..d4aa376b1be 100644 --- a/README.md +++ b/README.md @@ -120,9 +120,9 @@ Alternatively, use online services (like Google Colab): # Debian-based: sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0 # Red Hat-based: -sudo dnf install wget git python3 +sudo dnf install wget git python3 gperftools-libs libglvnd-glx # openSUSE-based: -sudo zypper install wget git python3 +sudo zypper install wget git python3 libtcmalloc4 libglvnd # Arch-based: sudo pacman -S wget git python3 ``` diff --git a/webui.sh b/webui.sh index 3d0f87eed74..5c23c1d8fb3 100755 --- a/webui.sh +++ b/webui.sh @@ -87,7 +87,7 @@ delimiter="################################################################" printf "\n%s\n" "${delimiter}" printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n" -printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" +printf "\e[1m\e[34mTested on Debian 11 (Bullseye), Fedora 34+ and openSUSE Leap 15.4 or newer.\e[0m" printf "\n%s\n" "${delimiter}" # Do not run as root @@ -222,7 +222,7 @@ fi # Try using TCMalloc on Linux prepare_tcmalloc() { if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then - TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)" + TCMALLOC="$(PATH=/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)" if [[ ! -z "${TCMALLOC}" ]]; then echo "Using TCMalloc: ${TCMALLOC}" export LD_PRELOAD="${TCMALLOC}" From 5432d9301359945b595d5e6649c7a64b4bb0bfca Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 11 Nov 2023 03:38:55 +0900 Subject: [PATCH 0252/1174] fix added accordion settings options --- modules/ui.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index bf06776e935..f28de3543ae 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -4,6 +4,7 @@ import sys from functools import reduce import warnings +from contextlib import suppress import gradio as gr import gradio.utils @@ -270,9 +271,7 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Accordion("Open for Settings", open=False), gr.Column(variant='compact', elem_id="txt2img_settings") \ - if shared.opts.img2img_settings_accordion else gr.Column(variant='compact', elem_id="txt2img_settings"): - + with gr.Accordion("Open for Settings", open=False) if shared.opts.txt2img_settings_accordion else suppress(), gr.Column(variant='compact', elem_id="txt2img_settings"): scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): @@ -491,8 +490,7 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Accordion("Open for Settings", open=False), gr.Column(variant='compact', elem_id="img2img_settings") \ - if shared.opts.img2img_settings_accordion else gr.Column(variant='compact', elem_id="img2img_settings"): + with gr.Accordion("Open for Settings", open=False) if shared.opts.img2img_settings_accordion else suppress(), gr.Column(variant='compact', elem_id="img2img_settings"): copy_image_buttons = 
[] copy_image_destinations = {} From 3a4a6c43a4ca31056d5c09bb54e3eef24e6cf864 Mon Sep 17 00:00:00 2001 From: Emily Zeng Date: Fri, 10 Nov 2023 16:06:01 -0500 Subject: [PATCH 0253/1174] ExitStack as alternative to suppress --- modules/ui.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index f28de3543ae..ba0d8542b60 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -4,7 +4,7 @@ import sys from functools import reduce import warnings -from contextlib import suppress +from contextlib import ExitStack import gradio as gr import gradio.utils @@ -271,7 +271,11 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Accordion("Open for Settings", open=False) if shared.opts.txt2img_settings_accordion else suppress(), gr.Column(variant='compact', elem_id="txt2img_settings"): + with ExitStack() as stack: + if shared.opts.txt2img_settings_accordion: + stack.enter_context(gr.Accordion("Open for Settings", open=False)) + stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings")) + scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): @@ -490,7 +494,10 @@ def create_ui(): extra_tabs.__enter__() with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False): - with gr.Accordion("Open for Settings", open=False) if shared.opts.img2img_settings_accordion else suppress(), gr.Column(variant='compact', elem_id="img2img_settings"): + with ExitStack() as stack: + if shared.opts.img2img_settings_accordion: + stack.enter_context(gr.Accordion("Open for Settings", open=False)) + stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings")) copy_image_buttons = [] copy_image_destinations = {} From 0fc7dc1c04a046d95588651ffc4e71a7d40378d3 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 04:01:13 -0600 Subject: [PATCH 0254/1174] implementing script metadata and DAG sorting mechanism --- modules/extensions.py | 80 +++++++++++++++++++++--- modules/scripts.py | 141 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 197 insertions(+), 24 deletions(-) diff --git a/modules/extensions.py b/modules/extensions.py index bf9a1878f5d..e317a404150 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,3 +1,5 @@ +import configparser +import functools import os import threading @@ -23,8 +25,9 @@ class Extension: lock = threading.Lock() cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] - def __init__(self, name, path, enabled=True, is_builtin=False): + def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None): self.name = name + self.canonical_name = canonical_name or name.lower() self.path = path self.enabled = enabled self.status = '' @@ -37,6 +40,17 @@ def __init__(self, name, path, enabled=True, is_builtin=False): self.remote = None self.have_info_from_repo = False + @functools.cached_property + def metadata(self): + if os.path.isfile(os.path.join(self.path, "sd_webui_metadata.ini")): + try: + config = configparser.ConfigParser() + config.read(os.path.join(self.path, "sd_webui_metadata.ini")) + return config + except Exception: + errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", exc_info=True) + return None + def to_dict(self): return {x: getattr(self, x) for x in self.cached_fields} @@ -136,9 +150,6 @@ def 
fetch_and_reset_hard(self, commit='origin'): def list_extensions(): extensions.clear() - if not os.path.isdir(extensions_dir): - return - if shared.cmd_opts.disable_all_extensions: print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") elif shared.opts.disable_all_extensions == "all": @@ -148,18 +159,69 @@ def list_extensions(): elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") - extension_paths = [] + extension_dependency_map = {} + + # scan through extensions directory and load metadata for dirname in [extensions_dir, extensions_builtin_dir]: if not os.path.isdir(dirname): - return + continue for extension_dirname in sorted(os.listdir(dirname)): path = os.path.join(dirname, extension_dirname) if not os.path.isdir(path): continue - extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) + canonical_name = extension_dirname + requires = None + + if os.path.isfile(os.path.join(path, "sd_webui_metadata.ini")): + try: + config = configparser.ConfigParser() + config.read(os.path.join(path, "sd_webui_metadata.ini")) + canonical_name = config.get("Extension", "Name", fallback=canonical_name) + requires = config.get("Extension", "Requires", fallback=None) + continue + except Exception: + errors.report(f"Error reading sd_webui_metadata.ini for extension {extension_dirname}. " + f"Will load regardless.", exc_info=True) + + canonical_name = canonical_name.lower().strip() + + # check for duplicated canonical names + if canonical_name in extension_dependency_map: + errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions " + f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". " + f"The current loading extension will be discarded.", exc_info=False) + continue - for dirname, path, is_builtin in extension_paths: - extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) + # we want to wash the data to lowercase and remove whitespaces just in case + requires = [x.strip() for x in requires.lower().split(',')] if requires else [] + + extension_dependency_map[canonical_name] = { + "dirname": extension_dirname, + "path": path, + "requires": requires, + } + + # check for requirements + for (_, extension_data) in extension_dependency_map.items(): + dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires'] + requirement_met = True + for req in requires: + if req not in extension_dependency_map: + errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. " + f"The current loading extension will be discarded.", exc_info=False) + requirement_met = False + break + dep_dirname = extension_dependency_map[req]['dirname'] + if dep_dirname in shared.opts.disabled_extensions: + errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. 
" + f"The current loading extension will be discarded.", exc_info=False) + requirement_met = False + break + + is_builtin = dirname == extensions_builtin_dir + extension = Extension(name=dirname, path=path, + enabled=dirname not in shared.opts.disabled_extensions and requirement_met, + is_builtin=is_builtin) extensions.append(extension) diff --git a/modules/scripts.py b/modules/scripts.py index 5c6e0226e2e..e92a34a0e30 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -2,6 +2,7 @@ import re import sys import inspect +from graphlib import TopologicalSorter, CycleError from collections import namedtuple from dataclasses import dataclass @@ -314,15 +315,131 @@ def basedir(): def list_scripts(scriptdirname, extension, *, include_extensions=True): scripts_list = [] - - basedir = os.path.join(paths.script_path, scriptdirname) - if os.path.exists(basedir): - for filename in sorted(os.listdir(basedir)): - scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) + script_dependency_map = {} + + # build script dependency map + + root_script_basedir = os.path.join(paths.script_path, scriptdirname) + if os.path.exists(root_script_basedir): + for filename in sorted(os.listdir(root_script_basedir)): + script_dependency_map[filename] = { + "extension": None, + "extension_dirname": None, + "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)), + "requires": [], + "load_before": [], + "load_after": [], + } if include_extensions: for ext in extensions.active(): - scripts_list += ext.list_files(scriptdirname, extension) + extension_scripts_list = ext.list_files(scriptdirname, extension) + for extension_script in extension_scripts_list: + # this is built on the assumption that script name is unique. + # I think bad thing is gonna happen if name collide in the current implementation anyway, but we + # will need to refactor here if this assumption is broken later on. + if extension_script.filename in script_dependency_map: + errors.report(f"Duplicate script name \"{extension_script.filename}\" found in extensions " + f"\"{ext.name}\" and \"{script_dependency_map[extension_script.filename]['extension_dirname'] or 'builtin'}\". 
" + f"The current loading file will be discarded.", exc_info=False) + continue + + relative_path = scriptdirname + "/" + extension_script.filename + + requires = None + load_before = None + load_after = None + + if ext.metadata is not None: + requires = ext.metadata.get(relative_path, "Requires", fallback=None) + load_before = ext.metadata.get(relative_path, "Before", fallback=None) + load_after = ext.metadata.get(relative_path, "After", fallback=None) + + requires = [x.strip() for x in requires.split(',')] if requires else [] + load_after = [x.strip() for x in load_after.split(',')] if load_after else [] + load_before = [x.strip() for x in load_before.split(',')] if load_before else [] + + script_dependency_map[extension_script.filename] = { + "extension": ext.canonical_name, + "extension_dirname": ext.name, + "script_file": extension_script, + "requires": requires, + "load_before": load_before, + "load_after": load_after, + } + + # resolve dependencies + + loaded_extensions = set() + for _, script_data in script_dependency_map.items(): + if script_data['extension'] is not None: + loaded_extensions.add(script_data['extension']) + + for script_filename, script_data in script_dependency_map.items(): + # load before requires inverse dependency + # in this case, append the script name into the load_after list of the specified script + for load_before_script in script_data['load_before']: + if load_before_script.startswith('ext:'): + # if this requires an extension to be loaded before + required_extension = load_before_script[4:] + for _, script_data2 in script_dependency_map.items(): + if script_data2['extension'] == required_extension: + script_data2['load_after'].append(script_filename) + break + else: + # if this requires an individual script to be loaded before + if load_before_script in script_dependency_map: + script_dependency_map[load_before_script]['load_after'].append(script_filename) + + # resolve extension name in load_after lists + for load_after_script in script_data['load_after']: + if load_after_script.startswith('ext:'): + # if this requires an extension to be loaded after + required_extension = load_after_script[4:] + for script_file_name2, script_data2 in script_dependency_map.items(): + if script_data2['extension'] == required_extension: + script_data['load_after'].append(script_file_name2) + + # remove all extension names in load_after lists + script_data['load_after'] = [x for x in script_data['load_after'] if not x.startswith('ext:')] + + # build the DAG + sorter = TopologicalSorter() + for script_filename, script_data in script_dependency_map.items(): + requirement_met = True + for required_script in script_data['requires']: + if required_script.startswith('ext:'): + # if this requires an extension to be installed + required_extension = required_script[4:] + if required_extension not in loaded_extensions: + errors.report(f"Script \"{script_filename}\" requires extension \"{required_extension}\" to " + f"be loaded, but it is not. Skipping.", + exc_info=False) + requirement_met = False + break + else: + # if this requires an individual script to be loaded + if required_script not in script_dependency_map: + errors.report(f"Script \"{script_filename}\" requires script \"{required_script}\" to " + f"be loaded, but it is not. 
Skipping.", + exc_info=False) + requirement_met = False + break + if not requirement_met: + continue + + sorter.add(script_filename, *script_data['load_after']) + + # sort the scripts + try: + ordered_script = sorter.static_order() + except CycleError: + errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True) + ordered_script = script_dependency_map.keys() + + for script_filename in ordered_script: + script_data = script_dependency_map[script_filename] + scripts_list.append(script_data['script_file']) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] @@ -365,15 +482,9 @@ def register_scripts_from_module(module): elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) - def orderby(basedir): - # 1st webui, 2nd extensions-builtin, 3rd extensions - priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0} - for key in priority: - if basedir.startswith(key): - return priority[key] - return 9999 - - for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]): + # here the scripts_list is already ordered + # processing_script is not considered though + for scriptfile in scripts_list: try: if scriptfile.basedir != paths.script_path: sys.path = [scriptfile.basedir] + sys.path From 0d1924c48be3d02650e87b12a4f53165a8b4a599 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 04:03:55 -0600 Subject: [PATCH 0255/1174] populate loaded_extensions from extension list instead --- modules/scripts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index e92a34a0e30..7cdf288df4e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -371,9 +371,8 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): # resolve dependencies loaded_extensions = set() - for _, script_data in script_dependency_map.items(): - if script_data['extension'] is not None: - loaded_extensions.add(script_data['extension']) + for ext in extensions.active(): + loaded_extensions.add(ext.canonical_name) for script_filename, script_data in script_dependency_map.items(): # load before requires inverse dependency From bc1a450124ab643fc0c3ea7630d875afb4b84b84 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 04:08:45 -0600 Subject: [PATCH 0256/1174] reverse the extension load order so builtin extensions load earlier natively --- modules/extensions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/extensions.py b/modules/extensions.py index e317a404150..7583a3b0372 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -162,7 +162,7 @@ def list_extensions(): extension_dependency_map = {} # scan through extensions directory and load metadata - for dirname in [extensions_dir, extensions_builtin_dir]: + for dirname in [extensions_builtin_dir, extensions_dir]: if not os.path.isdir(dirname): continue From 294f8a514f982248cda1cafda30d35566f3a0321 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Sat, 11 Nov 2023 23:28:12 +0900 Subject: [PATCH 0257/1174] add hyperTile https://github.com/tfernd/HyperTile --- modules/processing.py | 27 ++++++++++++++++++++++++--- modules/shared_options.py | 2 ++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/modules/processing.py 
b/modules/processing.py index b0e240a4640..e2309534334 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -799,6 +799,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] + unet_object = p.sd_model.model + vae_model = p.sd_model.first_stage_model + try: + from hyper_tile import split_attention, flush + except (ImportError, ModuleNotFoundError): # pip install git+https://github.com/tfernd/HyperTile@2ef64b2800d007d305755c33550537410310d7df + split_attention = lambda *args, **kwargs: lambda x: x # return a no-op context manager + flush = lambda: None + import random + saved_rng_state = random.getstate() + random.seed(p.seed) # hyper_tile uses random, so we need to seed it with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): @@ -866,15 +876,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): - samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + # get largest tile size available, which is 2^x which is factor of gcd of p.width and p.height + gcd = math.gcd(p.width, p.height) + largest_tile_size_available = 1 + while gcd % (largest_tile_size_available * 2) == 0: + largest_tile_size_available *= 2 + aspect_ratio = p.width / p.height + with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn): + with split_attention(unet_object, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn): + flush() + samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn): + flush() + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -980,6 +1000,7 @@ def infotext(index=0, use_main_prompt=False): if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True) + random.setstate(saved_rng_state) if not p.disable_extra_networks and p.extra_network_data: extra_networks.deactivate(p, p.extra_network_data) diff --git a/modules/shared_options.py b/modules/shared_options.py index d40db5306ee..d965026568e 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -200,6 +200,8 @@ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different 
lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), + "hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), + "hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { From 7af576e745c79a9539e40bc158e695192ae79f25 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 10:46:47 -0600 Subject: [PATCH 0258/1174] remove the assumption of same name --- modules/scripts.py | 81 +++++++++++++++++----------------------------- 1 file changed, 30 insertions(+), 51 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 7cdf288df4e..7ad22245142 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -335,15 +335,9 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): for ext in extensions.active(): extension_scripts_list = ext.list_files(scriptdirname, extension) for extension_script in extension_scripts_list: - # this is built on the assumption that script name is unique. - # I think bad thing is gonna happen if name collide in the current implementation anyway, but we - # will need to refactor here if this assumption is broken later on. - if extension_script.filename in script_dependency_map: - errors.report(f"Duplicate script name \"{extension_script.filename}\" found in extensions " - f"\"{ext.name}\" and \"{script_dependency_map[extension_script.filename]['extension_dirname'] or 'builtin'}\". 
" - f"The current loading file will be discarded.", exc_info=False) - continue - + script_canonical_name = ext.canonical_name + "/" + extension_script.filename + if ext.is_builtin: + script_canonical_name = "builtin/" + script_canonical_name relative_path = scriptdirname + "/" + extension_script.filename requires = None @@ -359,7 +353,7 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): load_after = [x.strip() for x in load_after.split(',')] if load_after else [] load_before = [x.strip() for x in load_before.split(',')] if load_before else [] - script_dependency_map[extension_script.filename] = { + script_dependency_map[script_canonical_name] = { "extension": ext.canonical_name, "extension_dirname": ext.name, "script_file": extension_script, @@ -374,60 +368,45 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): for ext in extensions.active(): loaded_extensions.add(ext.canonical_name) - for script_filename, script_data in script_dependency_map.items(): + for script_canonical_name, script_data in script_dependency_map.items(): # load before requires inverse dependency # in this case, append the script name into the load_after list of the specified script for load_before_script in script_data['load_before']: - if load_before_script.startswith('ext:'): - # if this requires an extension to be loaded before - required_extension = load_before_script[4:] + # if this requires an individual script to be loaded before + if load_before_script in script_dependency_map: + script_dependency_map[load_before_script]['load_after'].append(script_canonical_name) + elif load_before_script in loaded_extensions: for _, script_data2 in script_dependency_map.items(): - if script_data2['extension'] == required_extension: - script_data2['load_after'].append(script_filename) + if script_data2['extension'] == load_before_script: + script_data2['load_after'].append(script_canonical_name) break - else: - # if this requires an individual script to be loaded before - if load_before_script in script_dependency_map: - script_dependency_map[load_before_script]['load_after'].append(script_filename) # resolve extension name in load_after lists - for load_after_script in script_data['load_after']: - if load_after_script.startswith('ext:'): - # if this requires an extension to be loaded after - required_extension = load_after_script[4:] - for script_file_name2, script_data2 in script_dependency_map.items(): - if script_data2['extension'] == required_extension: - script_data['load_after'].append(script_file_name2) - - # remove all extension names in load_after lists - script_data['load_after'] = [x for x in script_data['load_after'] if not x.startswith('ext:')] + for load_after_script in list(script_data['load_after']): + if load_after_script not in script_dependency_map and load_after_script in loaded_extensions: + script_data['load_after'].remove(load_after_script) + for script_canonical_name2, script_data2 in script_dependency_map.items(): + if script_data2['extension'] == load_after_script: + script_data['load_after'].remove(script_canonical_name2) + break # build the DAG sorter = TopologicalSorter() - for script_filename, script_data in script_dependency_map.items(): + for script_canonical_name, script_data in script_dependency_map.items(): requirement_met = True for required_script in script_data['requires']: - if required_script.startswith('ext:'): - # if this requires an extension to be installed - required_extension = required_script[4:] - if required_extension not in 
loaded_extensions: - errors.report(f"Script \"{script_filename}\" requires extension \"{required_extension}\" to " - f"be loaded, but it is not. Skipping.", - exc_info=False) - requirement_met = False - break - else: - # if this requires an individual script to be loaded - if required_script not in script_dependency_map: - errors.report(f"Script \"{script_filename}\" requires script \"{required_script}\" to " - f"be loaded, but it is not. Skipping.", - exc_info=False) - requirement_met = False - break + # if this requires an individual script to be loaded + if required_script not in script_dependency_map and required_script not in loaded_extensions: + errors.report(f"Script \"{script_canonical_name}\" " + f"requires \"{required_script}\" to " + f"be loaded, but it is not. Skipping.", + exc_info=False) + requirement_met = False + break if not requirement_met: continue - sorter.add(script_filename, *script_data['load_after']) + sorter.add(script_canonical_name, *script_data['load_after']) # sort the scripts try: @@ -436,8 +415,8 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True) ordered_script = script_dependency_map.keys() - for script_filename in ordered_script: - script_data = script_dependency_map[script_filename] + for script_canonical_name in ordered_script: + script_data = script_dependency_map[script_canonical_name] scripts_list.append(script_data['script_file']) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] From 520e52f846892cc2b207b738b4180fa863c7b38f Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 10:58:26 -0600 Subject: [PATCH 0259/1174] allow comma and whitespace as separator --- modules/extensions.py | 9 ++++++--- modules/scripts.py | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/modules/extensions.py b/modules/extensions.py index 7583a3b0372..795af996e77 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -2,6 +2,7 @@ import functools import os import threading +import re from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo @@ -48,7 +49,8 @@ def metadata(self): config.read(os.path.join(self.path, "sd_webui_metadata.ini")) return config except Exception: - errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", exc_info=True) + errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", + exc_info=True) return None def to_dict(self): @@ -70,6 +72,7 @@ def read_from_repo(): self.do_read_info_from_repo() return self.to_dict() + try: d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) self.from_dict(d) @@ -194,8 +197,8 @@ def list_extensions(): f"The current loading extension will be discarded.", exc_info=False) continue - # we want to wash the data to lowercase and remove whitespaces just in case - requires = [x.strip() for x in requires.lower().split(',')] if requires else [] + # both "," and " " are accepted as separator + requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else [] extension_dependency_map[canonical_name] = { "dirname": extension_dirname, diff --git a/modules/scripts.py b/modules/scripts.py index 7ad22245142..5dd0555dd6a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -349,9 +349,9 @@ def 
list_scripts(scriptdirname, extension, *, include_extensions=True): load_before = ext.metadata.get(relative_path, "Before", fallback=None) load_after = ext.metadata.get(relative_path, "After", fallback=None) - requires = [x.strip() for x in requires.split(',')] if requires else [] - load_after = [x.strip() for x in load_after.split(',')] if load_after else [] - load_before = [x.strip() for x in load_before.split(',')] if load_before else [] + requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else [] + load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else [] + load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else [] script_dependency_map[script_canonical_name] = { "extension": ext.canonical_name, From 48d6102b3105bb0179c8eab14ec7930945aca326 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 11:17:26 -0600 Subject: [PATCH 0260/1174] fix --- modules/extensions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/extensions.py b/modules/extensions.py index 795af996e77..5536db3eab1 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -183,7 +183,6 @@ def list_extensions(): config.read(os.path.join(path, "sd_webui_metadata.ini")) canonical_name = config.get("Extension", "Name", fallback=canonical_name) requires = config.get("Extension", "Requires", fallback=None) - continue except Exception: errors.report(f"Error reading sd_webui_metadata.ini for extension {extension_dirname}. " f"Will load regardless.", exc_info=True) From 3bb32befe9523a6acefbab7fe099f91660f41ea9 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 11:58:19 -0600 Subject: [PATCH 0261/1174] bug fix --- modules/scripts.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 5dd0555dd6a..b1f4504a5d3 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -322,6 +322,9 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): root_script_basedir = os.path.join(paths.script_path, scriptdirname) if os.path.exists(root_script_basedir): for filename in sorted(os.listdir(root_script_basedir)): + if not os.path.isfile(os.path.join(root_script_basedir, filename)): + continue + script_dependency_map[filename] = { "extension": None, "extension_dirname": None, @@ -335,19 +338,27 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True): for ext in extensions.active(): extension_scripts_list = ext.list_files(scriptdirname, extension) for extension_script in extension_scripts_list: + if not os.path.isfile(extension_script.path): + continue + script_canonical_name = ext.canonical_name + "/" + extension_script.filename if ext.is_builtin: script_canonical_name = "builtin/" + script_canonical_name relative_path = scriptdirname + "/" + extension_script.filename - requires = None - load_before = None - load_after = None + requires = '' + load_before = '' + load_after = '' if ext.metadata is not None: - requires = ext.metadata.get(relative_path, "Requires", fallback=None) - load_before = ext.metadata.get(relative_path, "Before", fallback=None) - load_after = ext.metadata.get(relative_path, "After", fallback=None) + requires = ext.metadata.get(relative_path, "Requires", fallback='') + load_before = ext.metadata.get(relative_path, "Before", fallback='') + load_after = ext.metadata.get(relative_path, "After", fallback='') + + # propagate directory level metadata + requires = requires + 
',' + ext.metadata.get(scriptdirname, "Requires", fallback='')
+            load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='')
+            load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='')
 
             requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
             load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else []
@@ -387,7 +398,7 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True):
                 script_data['load_after'].remove(load_after_script)
                 for script_canonical_name2, script_data2 in script_dependency_map.items():
                     if script_data2['extension'] == load_after_script:
-                        script_data['load_after'].remove(script_canonical_name2)
+                        script_data['load_after'].append(script_canonical_name2)
                         break
 
     # build the DAG

From f6762d2ad95e3de39fc900b3fd528310e512831f Mon Sep 17 00:00:00 2001
From: Tom Haelbich <65122811+h43lb1t0@users.noreply.github.com>
Date: Sun, 12 Nov 2023 14:14:16 +0100
Subject: [PATCH 0262/1174] dir buttons start with / so only the correct dir
 will be shown, and not dirs whose names merely contain it as a substring

---
 modules/ui_extra_networks.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 063bd7b80e6..43a94b74bec 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -138,8 +138,9 @@ def create_html(self, tabname):
                 continue
 
             subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
-        while subdir.startswith("/"):
-            subdir = subdir[1:]
+
+        if not subdir.startswith("/"):
+            subdir = "/" + subdir
 
             is_empty = len(os.listdir(x)) == 0
             if not is_empty and not subdir.endswith("/"):

From 8048f36072c8a281b8c8c79235df63a748ab7361 Mon Sep 17 00:00:00 2001
From: missionfloyd
Date: Sun, 12 Nov 2023 17:12:50 -0700
Subject: [PATCH 0263/1174] Lint

---
 modules/ui_extra_networks.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 43a94b74bec..bd673285680 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -138,7 +138,6 @@ def create_html(self, tabname):
                 continue
 
             subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
-
             if not subdir.startswith("/"):
                 subdir = "/" + subdir
 

From 94e966956666ba13b368aaf781628085e3d4f7e3 Mon Sep 17 00:00:00 2001
From: kaalibro
Date: Mon, 13 Nov 2023 14:51:06 +0600
Subject: [PATCH 0264/1174] Fixes generation restart not working for some
 users when 'Ctrl+Enter' is pressed

---
 script.js | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/script.js b/script.js
index 5f6ee35427b..c0e678ea702 100644
--- a/script.js
+++ b/script.js
@@ -133,9 +133,18 @@ document.addEventListener('keydown', function(e) {
     if (isEnter && isModifierKey) {
         if (interruptButton.style.display === 'block') {
             interruptButton.click();
-            setTimeout(function() {
-                generateButton.click();
-            }, 500);
+            const callback = (mutationList) => {
+                for (const mutation of mutationList) {
+                    if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
+                        if (interruptButton.style.display === 'none') {
+                            generateButton.click();
+                            observer.disconnect();
+                        }
+                    }
+                }
+            };
+            const observer = new MutationObserver(callback);
+            observer.observe(interruptButton, {attributes: true});
     } else {
         generateButton.click();
     }

From c1c816006e47f3b7dcf1512594fd31817242e7fa Mon Sep 17 00:00:00 2001
From: kaalibro
Date: Mon, 13 Nov 2023 22:01:52 +0600
Subject: [PATCH 
0265/1174] Adds 'Path' sorting for Extra network cards --- modules/shared_options.py | 2 +- modules/ui_extra_networks.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index d40db5306ee..8fc7ef1d26e 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -235,7 +235,7 @@ "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), - "extra_networks_card_order_field": OptionInfo("Name", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), + "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 7907cd63f9c..f03e20337c2 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -279,6 +279,7 @@ def get_sort_keys(self, path): "date_created": int(stat.st_ctime or 0), "date_modified": int(stat.st_mtime or 0), "name": pth.name.lower(), + "path": str(pth.parent).lower(), } def find_preview(self, path): @@ -382,7 +383,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) - dropdown_sort = gr.Dropdown(choices=['Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") + dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) From a292d2c47f51fc71cc186709bdf3706f0944b7d6 Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Wed, 15 Nov 2023 14:26:37 +0900 Subject: [PATCH 0266/1174] hotfix: call shared.state.end() after postprocessing done --- modules/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/modules/postprocessing.py b/modules/postprocessing.py index cf04d38b059..fd0c0cc99c5 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -78,7 +78,7 @@ def get_images(extras_mode, image, image_folder, input_dir): image_data.close() devices.torch_gc() - + shared.state.end() return outputs, ui_common.plaintext_to_html(infotext), '' From b29fc6d4de8812b25c520a46676cda13c3fe64ca Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Sat, 11 Nov 2023 23:43:13 +0900 Subject: [PATCH 0267/1174] Implement Hypertile Co-Authored-By: Kieran Hunt --- modules/hypertile.py | 333 ++++++++++++++++++++++++++++++++++++++++++ modules/processing.py | 65 ++++----- 2 files changed, 358 insertions(+), 40 deletions(-) create mode 100644 modules/hypertile.py diff --git a/modules/hypertile.py b/modules/hypertile.py new file mode 100644 index 00000000000..ab1c74c022f --- /dev/null +++ b/modules/hypertile.py @@ -0,0 +1,333 @@ +""" +Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE +Warn : The patch works well only if the input image has a width and height that are multiples of 128 +Author : @tfernd Github : https://github.com/tfernd/HyperTile +""" + +from __future__ import annotations +from typing import Callable +from typing_extensions import Literal + +import logging +from functools import wraps, cache +from contextlib import contextmanager + +import math +import torch.nn as nn +import random + +from einops import rearrange + +# TODO add SD-XL layers +DEPTH_LAYERS = { + 0: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.0.attentions.0.transformer_blocks.0.attn1", + "down_blocks.0.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.0.transformer_blocks.0.attn1", + "up_blocks.3.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.1.1.transformer_blocks.0.attn1", + "input_blocks.2.1.transformer_blocks.0.attn1", + "output_blocks.9.1.transformer_blocks.0.attn1", + "output_blocks.10.1.transformer_blocks.0.attn1", + "output_blocks.11.1.transformer_blocks.0.attn1", + # SD 1.5 VAE + "decoder.mid_block.attentions.0", + ], + 1: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.1.attentions.0.transformer_blocks.0.attn1", + "down_blocks.1.attentions.1.transformer_blocks.0.attn1", + "up_blocks.2.attentions.0.transformer_blocks.0.attn1", + "up_blocks.2.attentions.1.transformer_blocks.0.attn1", + "up_blocks.2.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.0.attn1", + "input_blocks.5.1.transformer_blocks.0.attn1", + "output_blocks.6.1.transformer_blocks.0.attn1", + "output_blocks.7.1.transformer_blocks.0.attn1", + "output_blocks.8.1.transformer_blocks.0.attn1", + ], + 2: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.2.attentions.0.transformer_blocks.0.attn1", + "down_blocks.2.attentions.1.transformer_blocks.0.attn1", + "up_blocks.1.attentions.0.transformer_blocks.0.attn1", + "up_blocks.1.attentions.1.transformer_blocks.0.attn1", + "up_blocks.1.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.7.1.transformer_blocks.0.attn1", + "input_blocks.8.1.transformer_blocks.0.attn1", + "output_blocks.3.1.transformer_blocks.0.attn1", + "output_blocks.4.1.transformer_blocks.0.attn1", + "output_blocks.5.1.transformer_blocks.0.attn1", + ], + 3: [ + # SD 1.5 U-Net (diffusers) + "mid_block.attentions.0.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + 
"middle_block.1.transformer_blocks.0.attn1", + ], +} +# XL layers, thanks for GitHub@gel-crabs for the help +DEPTH_LAYERS_XL = { + 0: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.0.attentions.0.transformer_blocks.0.attn1", + "down_blocks.0.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.0.transformer_blocks.0.attn1", + "up_blocks.3.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.0.attn1", + "input_blocks.5.1.transformer_blocks.0.attn1", + "output_blocks.3.1.transformer_blocks.0.attn1", + "output_blocks.4.1.transformer_blocks.0.attn1", + "output_blocks.5.1.transformer_blocks.0.attn1", + # SD 1.5 VAE + "decoder.mid_block.attentions.0", + "decoder.mid.attn_1", + ], + 1: [ + # SD 1.5 U-Net (diffusers) + #"down_blocks.1.attentions.0.transformer_blocks.0.attn1", + #"down_blocks.1.attentions.1.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.0.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.1.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.1.attn1", + "input_blocks.5.1.transformer_blocks.1.attn1", + "output_blocks.3.1.transformer_blocks.1.attn1", + "output_blocks.4.1.transformer_blocks.1.attn1", + "output_blocks.5.1.transformer_blocks.1.attn1", + "input_blocks.7.1.transformer_blocks.0.attn1", + "input_blocks.8.1.transformer_blocks.0.attn1", + "output_blocks.0.1.transformer_blocks.0.attn1", + "output_blocks.1.1.transformer_blocks.0.attn1", + "output_blocks.2.1.transformer_blocks.0.attn1", + "input_blocks.7.1.transformer_blocks.1.attn1", + "input_blocks.8.1.transformer_blocks.1.attn1", + "output_blocks.0.1.transformer_blocks.1.attn1", + "output_blocks.1.1.transformer_blocks.1.attn1", + "output_blocks.2.1.transformer_blocks.1.attn1", + "input_blocks.7.1.transformer_blocks.2.attn1", + "input_blocks.8.1.transformer_blocks.2.attn1", + "output_blocks.0.1.transformer_blocks.2.attn1", + "output_blocks.1.1.transformer_blocks.2.attn1", + "output_blocks.2.1.transformer_blocks.2.attn1", + "input_blocks.7.1.transformer_blocks.3.attn1", + "input_blocks.8.1.transformer_blocks.3.attn1", + "output_blocks.0.1.transformer_blocks.3.attn1", + "output_blocks.1.1.transformer_blocks.3.attn1", + "output_blocks.2.1.transformer_blocks.3.attn1", + "input_blocks.7.1.transformer_blocks.4.attn1", + "input_blocks.8.1.transformer_blocks.4.attn1", + "output_blocks.0.1.transformer_blocks.4.attn1", + "output_blocks.1.1.transformer_blocks.4.attn1", + "output_blocks.2.1.transformer_blocks.4.attn1", + "input_blocks.7.1.transformer_blocks.5.attn1", + "input_blocks.8.1.transformer_blocks.5.attn1", + "output_blocks.0.1.transformer_blocks.5.attn1", + "output_blocks.1.1.transformer_blocks.5.attn1", + "output_blocks.2.1.transformer_blocks.5.attn1", + "input_blocks.7.1.transformer_blocks.6.attn1", + "input_blocks.8.1.transformer_blocks.6.attn1", + "output_blocks.0.1.transformer_blocks.6.attn1", + "output_blocks.1.1.transformer_blocks.6.attn1", + "output_blocks.2.1.transformer_blocks.6.attn1", + "input_blocks.7.1.transformer_blocks.7.attn1", + "input_blocks.8.1.transformer_blocks.7.attn1", + "output_blocks.0.1.transformer_blocks.7.attn1", + "output_blocks.1.1.transformer_blocks.7.attn1", + "output_blocks.2.1.transformer_blocks.7.attn1", + "input_blocks.7.1.transformer_blocks.8.attn1", + "input_blocks.8.1.transformer_blocks.8.attn1", + "output_blocks.0.1.transformer_blocks.8.attn1", + 
"output_blocks.1.1.transformer_blocks.8.attn1", + "output_blocks.2.1.transformer_blocks.8.attn1", + "input_blocks.7.1.transformer_blocks.9.attn1", + "input_blocks.8.1.transformer_blocks.9.attn1", + "output_blocks.0.1.transformer_blocks.9.attn1", + "output_blocks.1.1.transformer_blocks.9.attn1", + "output_blocks.2.1.transformer_blocks.9.attn1", + ], + 2: [ + # SD 1.5 U-Net (diffusers) + "mid_block.attentions.0.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "middle_block.1.transformer_blocks.0.attn1", + "middle_block.1.transformer_blocks.1.attn1", + "middle_block.1.transformer_blocks.2.attn1", + "middle_block.1.transformer_blocks.3.attn1", + "middle_block.1.transformer_blocks.4.attn1", + "middle_block.1.transformer_blocks.5.attn1", + "middle_block.1.transformer_blocks.6.attn1", + "middle_block.1.transformer_blocks.7.attn1", + "middle_block.1.transformer_blocks.8.attn1", + "middle_block.1.transformer_blocks.9.attn1", + ], +} + + +RNG_INSTANCE = random.Random() + +def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: + """ + Returns a random divisor of value that + x * min_value <= value + if max_options is 1, the behavior is deterministic + """ + min_value = min(min_value, value) + + # All big divisors of value (inclusive) + divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order + + ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order + + idx = RNG_INSTANCE.randint(0, len(ns) - 1) + + return ns[idx] + +def set_hypertile_seed(seed: int) -> None: + RNG_INSTANCE.seed(seed) + +def largest_tile_size_available(width:int, height:int) -> int: + """ + Calculates the largest tile size available for a given width and height + Tile size is always a power of 2 + """ + gcd = math.gcd(width, height) + largest_tile_size_available = 1 + while gcd % (largest_tile_size_available * 2) == 0: + largest_tile_size_available *= 2 + return largest_tile_size_available + +def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]: + """ + Finds h and w such that h*w = hw and h/w = aspect_ratio + We check all possible divisors of hw and return the closest to the aspect ratio + """ + divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw + pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw + ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw + closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio + closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio + return closest_pair + +@cache +def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]: + """ + Finds h and w such that h*w = hw and h/w = aspect_ratio + """ + h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) + # find h and w such that h*w = hw and h/w = aspect_ratio + if h * w != hw: + w_candidate = hw / h + # check if w is an integer + if not w_candidate.is_integer(): + h_candidate = hw / w + # check if h is an integer + if not h_candidate.is_integer(): + return iterative_closest_divisors(hw, aspect_ratio) + else: + h = int(h_candidate) + else: + w = int(w_candidate) + return h, w + +@contextmanager +def split_attention( + layer: nn.Module, + /, + aspect_ratio: float, # width/height + tile_size: int = 128, # 128 for VAE + swap_size: int = 1, # 1 for VAE + *, + disable: bool = False, + max_depth: Literal[0, 1, 2, 3] = 0, # ! 
Try 0 or 1 + scale_depth: bool = True, # scale the tile-size depending on the depth + is_sdxl: bool = False, # is the model SD-XL +): + # Hijacks AttnBlock from ldm and Attention from diffusers + + if disable: + logging.info(f"Attention for {layer.__class__.__qualname__} not splitted") + yield + return + + latent_tile_size = max(128, tile_size) // 8 + + def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable: + @wraps(forward) + def wrapper(*args, **kwargs): + x = args[0] + + # VAE + if x.ndim == 4: + b, c, h, w = x.shape + + nh = random_divisor(h, latent_tile_size, swap_size) + nw = random_divisor(w, latent_tile_size, swap_size) + + if nh * nw > 1: + x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles + + out = forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw) + + # U-Net + else: + hw: int = x.size(1) + h, w = find_hw_candidates(hw, aspect_ratio) + assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}" + + factor = 2**depth if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, swap_size) + nw = random_divisor(w, latent_tile_size * factor, swap_size) + + module._split_sizes_hypertile.append((nh, nw)) # type: ignore + + if nh * nw > 1: + x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + + out = forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) + out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + + return out + + return wrapper + + # Handle hijacking the forward method and recovering afterwards + try: + if is_sdxl: + layers = DEPTH_LAYERS_XL + else: + layers = DEPTH_LAYERS + for depth in range(max_depth + 1): + for layer_name, module in layer.named_modules(): + if any(layer_name.endswith(try_name) for try_name in layers[depth]): + # print input shape for debugging + logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}") + # hijack + module._original_forward_hypertile = module.forward + module.forward = self_attn_forward(module.forward, depth, layer_name, module) + module._split_sizes_hypertile = [] + yield + finally: + for layer_name, module in layer.named_modules(): + # remove hijack + if hasattr(module, "_original_forward_hypertile"): + if module._split_sizes_hypertile: + logging.debug(f"layer {layer_name} splitted with ({module._split_sizes_hypertile})") + # recover + module.forward = module._original_forward_hypertile + del module._original_forward_hypertile + del module._split_sizes_hypertile diff --git a/modules/processing.py b/modules/processing.py index e2309534334..e19a09a3c60 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -24,6 +24,7 @@ import modules.shared as shared import modules.paths as paths import modules.face_restoration +from modules.hypertile import split_attention, set_hypertile_seed, largest_tile_size_available import modules.images as images import modules.styles import modules.sd_models as sd_models @@ -799,17 +800,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - unet_object = p.sd_model.model - vae_model = p.sd_model.first_stage_model - try: - from hyper_tile import split_attention, flush - except (ImportError, ModuleNotFoundError): # pip install 
git+https://github.com/tfernd/HyperTile@2ef64b2800d007d305755c33550537410310d7df - split_attention = lambda *args, **kwargs: lambda x: x # return a no-op context manager - flush = lambda: None - import random - saved_rng_state = random.getstate() - random.seed(p.seed) # hyper_tile uses random, so we need to seed it - with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) @@ -871,29 +861,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.comment(comment) p.extra_generation_params.update(model_hijack.extra_generation_params) - + set_hypertile_seed(p.seed) + # add batch size + hypertile status to information to reproduce the run if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): - # get largest tile size available, which is 2^x which is factor of gcd of p.width and p.height - gcd = math.gcd(p.width, p.height) - largest_tile_size_available = 1 - while gcd % (largest_tile_size_available * 2) == 0: - largest_tile_size_available *= 2 - aspect_ratio = p.width / p.height - with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn): - with split_attention(unet_object, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn): - flush() - samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn): - flush() + with split_attention(p.sd_model.first_stage_model, aspect_ratio = p.width / p.height, tile_size=min(largest_tile_size_available(p.width, p.height), 128), disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -1000,7 +981,6 @@ def infotext(index=0, use_main_prompt=False): if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True) - random.setstate(saved_rng_state) if not p.disable_extra_networks and p.extra_network_data: extra_networks.deactivate(p, p.extra_network_data) @@ -1161,24 +1141,25 @@ def init(self, all_prompts, all_seeds, all_subseeds): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - + aspect_ratio = self.width / self.height x = self.rng.next() - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + tile_size = 
largest_tile_size_available(self.width, self.height) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + devices.torch_gc() + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x - if not self.enable_hr: return samples if self.latent_scale_mode is None: - decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) else: decoded_samples = None with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) - - devices.torch_gc() - return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): @@ -1186,7 +1167,6 @@ def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_stre return samples self.is_hr_pass = True - target_width = self.hr_upscale_to_x target_height = self.hr_upscale_to_y @@ -1264,18 +1244,19 @@ def save_intermediate(image, index): if self.scripts is not None: self.scripts.before_hr(self) - - samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) + tile_size = largest_tile_size_available(target_width, target_height) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with split_attention(self.sd_model.model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=3, max_depth=1,scale_depth=True, disable=not opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) self.sampler = None devices.torch_gc() - - decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) self.is_hr_pass = False - return decoded_samples def close(self): @@ -1550,8 +1531,12 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs if self.initial_noise_multiplier != 1.0: 
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier x *= self.initial_noise_multiplier - - samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) + aspect_ratio = self.width / self.height + tile_size = largest_tile_size_available(self.width, self.height) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + devices.torch_gc() + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask From af45872fdb8a66ffd6a405d99120e0bacbb4a170 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:15:14 +0900 Subject: [PATCH 0268/1174] copy LDM VAE key from XL --- modules/hypertile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/hypertile.py b/modules/hypertile.py index ab1c74c022f..32d8604cc97 100644 --- a/modules/hypertile.py +++ b/modules/hypertile.py @@ -35,6 +35,7 @@ "output_blocks.11.1.transformer_blocks.0.attn1", # SD 1.5 VAE "decoder.mid_block.attentions.0", + "decoder.mid.attn_1", ], 1: [ # SD 1.5 U-Net (diffusers) From d6d0b22e6657fc84039e82ee735a57101bfe7c17 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 15 Nov 2023 03:08:50 -0800 Subject: [PATCH 0269/1174] fix: ignore calc_scale() for COFT which has very small alpha --- extensions-builtin/Lora/network_oft.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 93402bb28f1..c45a8d23a75 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -99,12 +99,9 @@ def calc_updown_kb(self, orig_weight, multiplier): is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] if not is_other_linear: - #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: - # orig_weight=orig_weight.permute(1, 0) - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - # without this line the results are significantly worse / less accurate + # ensure skew-symmetric matrix oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) @@ -118,9 +115,6 @@ def calc_updown_kb(self, orig_weight, multiplier): ) merged_weight = rearrange(merged_weight, 'k m ... 
-> (k m) ...')
 
-        #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
-        #    orig_weight=orig_weight.permute(1, 0)
-
         updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
         output_shape = orig_weight.shape
     else:
@@ -132,10 +126,10 @@
         return self.finalize_updown(updown, orig_weight, output_shape)
 
     def calc_updown(self, orig_weight):
-        multiplier = self.multiplier() * self.calc_scale()
-        #if self.is_kohya:
-        #    return self.calc_updown_kohya(orig_weight, multiplier)
-        #else:
+        # if alpha is a very small number as in coft, calc_scale will return an almost zero number so we ignore it
+        #multiplier = self.multiplier() * self.calc_scale()
+        multiplier = self.multiplier()
+
        return self.calc_updown_kb(orig_weight, multiplier)
 
     # override to remove the multiplier/scale factor; it's already multiplied in get_weight

From eb667e715ad3eea981f6263c143ab0422e5340c9 Mon Sep 17 00:00:00 2001
From: v0xie <28695009+v0xie@users.noreply.github.com>
Date: Wed, 15 Nov 2023 18:28:48 -0800
Subject: [PATCH 0270/1174] feat: LyCORIS/kohya OFT network support

---
 extensions-builtin/Lora/network_oft.py | 108 ++++++-------------
 1 file changed, 26 insertions(+), 82 deletions(-)

diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index c45a8d23a75..05c3781187a 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -11,8 +11,8 @@ def create_module(self, net: network.Network, weights: network.NetworkWeights):
         return None
 
 
-# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
-# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
+# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
 class NetworkModuleOFT(network.NetworkModule):
     def __init__(self,  net: network.Network, weights: network.NetworkWeights):
 
@@ -25,117 +25,61 @@ def __init__(self,  net: network.Network, weights: network.NetworkWeights):
         if "oft_blocks" in weights.w.keys():
             self.is_kohya = True
             self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
-            self.alpha = weights.w["alpha"]
+            self.alpha = weights.w["alpha"] # alpha is constraint
             self.dim = self.oft_blocks.shape[0] # lora dim
-            #self.oft_blocks = rearrange(self.oft_blocks, 'k m ... 
-> (k m) ...') + # LyCORIS elif "oft_diag" in weights.w.keys(): self.is_kohya = False - self.oft_blocks = weights.w["oft_diag"] # (num_blocks, block_size, block_size) - - # alpha is rank if alpha is 0 or None - if self.alpha is None: - pass - self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] - else: - raise ValueError("oft_blocks or oft_diag must be in weights dict") + self.oft_blocks = weights.w["oft_diag"] + # self.alpha is unused + self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] - is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported if is_linear: self.out_dim = self.sd_module.out_features - elif is_other_linear: - self.out_dim = self.sd_module.embed_dim elif is_conv: self.out_dim = self.sd_module.out_channels - else: - raise ValueError("sd_module must be Linear or Conv") + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim if self.is_kohya: self.constraint = self.alpha * self.out_dim - self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.num_blocks = self.dim + self.block_size = self.out_dim // self.dim else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def merge_weight(self, R_weight, org_weight): - R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - return weight - - def get_weight(self, oft_blocks, multiplier=None): - if self.constraint is not None: - constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - if self.constraint is not None: - new_norm_Q = torch.clamp(norm_Q, max=constraint) - else: - new_norm_Q = norm_Q - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - m_I = torch.eye(self.num_blocks, device=oft_blocks.device).unsqueeze(0).repeat(self.block_size, 1, 1) - #m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + def calc_updown_kb(self, orig_weight, multiplier): + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix - block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - R = torch.block_diag(*block_R_weighted) - return R + R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) - def calc_updown_kohya(self, orig_weight, multiplier): - R = self.get_weight(self.oft_blocks, multiplier) - merged_weight = self.merge_weight(R, orig_weight) + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... 
-> (k m) ...')
 
         updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
         output_shape = orig_weight.shape
-    else:
-        # FIXME: skip MultiheadAttention for now
-        #up = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
-        updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
-        output_shape = (orig_weight.shape[1], orig_weight.shape[1])
-
         return self.finalize_updown(updown, orig_weight, output_shape)
 
     def calc_updown(self, orig_weight):
-        # if alpha is a very small number as in coft, calc_scale will return an almost zero number so we ignore it
+        # if alpha is a very small number as in coft, calc_scale() will return an almost zero number so we ignore it
         multiplier = self.multiplier()
-
         return self.calc_updown_kb(orig_weight, multiplier)
 
     # override to remove the multiplier/scale factor; it's already multiplied in get_weight
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
-        #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias)
-
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
             updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)

From bcfaf3979a9f93e37c418b58c75b02d9570b4354 Mon Sep 17 00:00:00 2001
From: AngelBottomless
Date: Thu, 16 Nov 2023 18:43:16 +0900
Subject: [PATCH 0271/1174] convert/add hypertile options

---
 modules/hypertile.py      | 36 ++++++++++++++++++++++++++++++++++++
 modules/processing.py     | 21 +++++++++++----------
 modules/shared_options.py |  6 ++++++
 3 files changed, 53 insertions(+), 10 deletions(-)

diff --git a/modules/hypertile.py b/modules/hypertile.py
index 32d8604cc97..fee24a8cac1 100644
--- a/modules/hypertile.py
+++ b/modules/hypertile.py
@@ -332,3 +332,39 @@ def wrapper(*args, **kwargs):
                 module.forward = module._original_forward_hypertile
                 del module._original_forward_hypertile
                 del module._split_sizes_hypertile
+
+def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts):
+    """
+    Returns context manager for VAE
+    """
+    enabled = not opts.hypertile_split_vae_attn
+    swap_size = opts.hypertile_swap_size_vae
+    max_depth = opts.hypertile_max_depth_vae
+    tile_size_max = opts.hypertile_max_tile_vae
+    return split_attention(
+        model,
+        aspect_ratio=aspect_ratio,
+        tile_size=min(tile_size, tile_size_max),
+        swap_size=swap_size,
+        disable=not enabled,
+        max_depth=max_depth,
+        is_sdxl=False,
+    )
+
+def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, 
is_sdxl:bool): + """ + Returns context manager for U-Net + """ + enabled = not opts.hypertile_split_unet_attn + swap_size = opts.hypertile_swap_size_unet + max_depth = opts.hypertile_max_depth_unet + tile_size_max = opts.hypertile_max_tile_unet + return split_attention( + model, + aspect_ratio=aspect_ratio, + tile_size=min(tile_size, tile_size_max), + swap_size=swap_size, + disable=not enabled, + max_depth=max_depth, + is_sdxl=is_sdxl, + ) \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index e19a09a3c60..c622ff337c0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -24,7 +24,7 @@ import modules.shared as shared import modules.paths as paths import modules.face_restoration -from modules.hypertile import split_attention, set_hypertile_seed, largest_tile_size_available +from modules.hypertile import set_hypertile_seed, largest_tile_size_available, hypertile_context_unet, hypertile_context_vae import modules.images as images import modules.styles import modules.sd_models as sd_models @@ -874,7 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - with split_attention(p.sd_model.first_stage_model, aspect_ratio = p.width / p.height, tile_size=min(largest_tile_size_available(p.width, p.height), 128), disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_unet(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -1144,8 +1144,8 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs aspect_ratio = self.width / self.height x = self.rng.next() tile_size = largest_tile_size_available(self.width, self.height) - with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): - with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): + with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): devices.torch_gc() samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x @@ -1153,7 +1153,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs return samples if self.latent_scale_mode is None: - with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, 
check_for_nans=True)).to(dtype=torch.float32) else: decoded_samples = None @@ -1245,15 +1245,16 @@ def save_intermediate(image, index): if self.scripts is not None: self.scripts.before_hr(self) tile_size = largest_tile_size_available(target_width, target_height) - with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): - with split_attention(self.sd_model.model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=3, max_depth=1,scale_depth=True, disable=not opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + aspect_ratio = self.width / self.height + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): + with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) self.sampler = None devices.torch_gc() - with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) self.is_hr_pass = False @@ -1533,8 +1534,8 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs x *= self.initial_noise_multiplier aspect_ratio = self.width / self.height tile_size = largest_tile_size_available(self.width, self.height) - with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): - with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): + with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): devices.torch_gc() samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) diff --git a/modules/shared_options.py b/modules/shared_options.py index d965026568e..28a48906937 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -202,6 +202,12 @@ "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), "hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", 
"https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), "hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), + "hypertile_max_depth_vae" : OptionInfo(3, "Max depth for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_max_depth_unet" : OptionInfo(3, "Max depth for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_max_tile_vae" : OptionInfo(128, "Max tile size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_max_tile_unet" : OptionInfo(256, "Max tile size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_swap_size_unet": OptionInfo(3, "Swap size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_swap_size_vae": OptionInfo(3, "Swap size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { From 472c22cc8a46b825545d5c86bd2745269430d7b0 Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Thu, 16 Nov 2023 19:03:45 +0900 Subject: [PATCH 0272/1174] fix ruff - add newline --- modules/hypertile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/hypertile.py b/modules/hypertile.py index fee24a8cac1..86acecdc08d 100644 --- a/modules/hypertile.py +++ b/modules/hypertile.py @@ -367,4 +367,4 @@ def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, o disable=not enabled, max_depth=max_depth, is_sdxl=is_sdxl, - ) \ No newline at end of file + ) From 236eb82c3a91960ba5db7b82efbe0f6a9fd7cf24 Mon Sep 17 00:00:00 2001 From: Lucas Daniel Velazquez M <19197331+Luxter77@users.noreply.github.com> Date: Thu, 16 Nov 2023 13:20:33 -0300 Subject: [PATCH 0273/1174] Adds tqdm handler to logging_config.py for progress bar integration --- modules/logging_config.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/modules/logging_config.py b/modules/logging_config.py index 7db23d4b6e5..ce831b5c632 100644 --- a/modules/logging_config.py +++ b/modules/logging_config.py @@ -1,6 +1,19 @@ import os import logging +from tqdm.auto import tqdm + +class TqdmLoggingHandler(logging.Handler): + def __init__(self, level=logging.INFO): + super().__init__(level) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.write(msg) + self.flush() + except Exception: + self.handleError(record) def setup_logging(loglevel): if loglevel is None: @@ -12,5 +25,6 @@ def setup_logging(loglevel): level=log_level, format='%(asctime)s %(levelname)s [%(name)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', + handlers=[TqdmLoggingHandler()] ) From cdb60a690dcd35e865eb0caef6c6d8ff64e1b0d5 Mon Sep 17 00:00:00 2001 From: Lucas Daniel Velazquez M <19197331+Luxter77@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:43:59 -0300 Subject: [PATCH 0274/1174] Take into account tqdm not being installed before first boot for logging --- 
 modules/logging_config.py | 37 ++++++++++++++++++++++++-------------
 1 file changed, 24 insertions(+), 13 deletions(-)

diff --git a/modules/logging_config.py b/modules/logging_config.py
index ce831b5c632..99ed2855b92 100644
--- a/modules/logging_config.py
+++ b/modules/logging_config.py
@@ -1,30 +1,41 @@
 import os
 import logging
 
-from tqdm.auto import tqdm
+try:
+    from tqdm.auto import tqdm
 
-class TqdmLoggingHandler(logging.Handler):
-    def __init__(self, level=logging.INFO):
-        super().__init__(level)
+    class TqdmLoggingHandler(logging.Handler):
+        def __init__(self, level=logging.INFO):
+            super().__init__(level)
 
-    def emit(self, record):
-        try:
-            msg = self.format(record)
-            tqdm.write(msg)
-            self.flush()
-        except Exception:
-            self.handleError(record)
+        def emit(self, record):
+            try:
+                msg = self.format(record)
+                tqdm.write(msg)
+                self.flush()
+            except Exception:
+                self.handleError(record)
+
+    TQDM_IMPORTED = True
+except ImportError:
+    # tqdm does not exist before first launch
+    # I will import once the UI finishes setting up the environment and reloads.
+    TQDM_IMPORTED = False
 
 def setup_logging(loglevel):
     if loglevel is None:
         loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
 
+    loghandlers = []
+
+    if TQDM_IMPORTED:
+        loghandlers.append(TqdmLoggingHandler())
+
     if loglevel:
         log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
         logging.basicConfig(
             level=log_level,
             format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
             datefmt='%Y-%m-%d %H:%M:%S',
-            handlers=[TqdmLoggingHandler()]
+            handlers=[]
         )
-

From 7021cdb1de12be3071ecb67aa8d2e34e7a0ec5ab Mon Sep 17 00:00:00 2001
From: Your Name
Date: Thu, 16 Nov 2023 17:53:57 -0300
Subject: [PATCH 0275/1174] actually adds handler to logging_config.py

---
 modules/logging_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/logging_config.py b/modules/logging_config.py
index 99ed2855b92..79269875608 100644
--- a/modules/logging_config.py
+++ b/modules/logging_config.py
@@ -37,5 +37,5 @@ def setup_logging(loglevel):
             level=log_level,
             format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
             datefmt='%Y-%m-%d %H:%M:%S',
-            handlers=[]
+            handlers=loghandlers
         )

From c40be2252ab1c8c289562db208c5ac6618bd8545 Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Fri, 17 Nov 2023 09:22:27 +0900
Subject: [PATCH 0276/1174] Fix critical issue - unet apply

---
 modules/processing.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/modules/processing.py b/modules/processing.py
index c622ff337c0..2fda7f332d9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -874,7 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         else:
             if opts.sd_vae_decode_method != 'Full':
                 p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
-            with hypertile_context_unet(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
+            with hypertile_context_unet(p.sd_model.model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
                 x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
                 x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -1145,7 +1145,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
         x = self.rng.next()
         tile_size = 
largest_tile_size_available(self.width, self.height) with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): + with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): devices.torch_gc() samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x @@ -1247,7 +1247,7 @@ def save_intermediate(image, index): tile_size = largest_tile_size_available(target_width, target_height) aspect_ratio = self.width / self.height with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): + with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) @@ -1535,7 +1535,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs aspect_ratio = self.width / self.height tile_size = largest_tile_size_available(self.width, self.height) with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): + with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): devices.torch_gc() samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) From c0725ba2d098a6a78610e7d96ee75f63a32d4e52 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 17 Nov 2023 09:34:50 +0900 Subject: [PATCH 0277/1174] Fix inverted option issue I'm pretty sure I was sleepy while implementing this --- modules/hypertile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/hypertile.py b/modules/hypertile.py index 86acecdc08d..3a1468c6bfc 100644 --- a/modules/hypertile.py +++ b/modules/hypertile.py @@ -337,7 +337,7 @@ def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, op """ Returns context manager for VAE """ - enabled = not opts.hypertile_split_vae_attn + enabled = opts.hypertile_split_vae_attn swap_size = opts.hypertile_swap_size_vae max_depth = opts.hypertile_max_depth_vae tile_size_max = opts.hypertile_max_tile_vae @@ -355,7 +355,7 @@ def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, o """ Returns context manager for U-Net """ - enabled = not opts.hypertile_split_unet_attn + enabled = opts.hypertile_split_unet_attn swap_size = opts.hypertile_swap_size_unet max_depth = opts.hypertile_max_depth_unet tile_size_max = opts.hypertile_max_tile_unet From 
ffd0f8ddc309688636ac1ac10d82b72ab6b466bf Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 17 Nov 2023 09:54:33 +0900 Subject: [PATCH 0278/1174] set empty value for SD XL 3rd layer --- modules/hypertile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/hypertile.py b/modules/hypertile.py index 3a1468c6bfc..be898fce43f 100644 --- a/modules/hypertile.py +++ b/modules/hypertile.py @@ -170,6 +170,7 @@ "middle_block.1.transformer_blocks.8.attn1", "middle_block.1.transformer_blocks.9.attn1", ], + 3 : [] # TODO - separate layers for SD-XL } From 97431f29feb17ffc96ca95e9b3efec87be9d8b3a Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:05:28 +0900 Subject: [PATCH 0279/1174] fix double gc and decoding with unet context --- modules/processing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 2fda7f332d9..36c2be5e536 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -874,7 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - with hypertile_context_unet(p.sd_model.model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): + with hypertile_context_vae(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), opts=shared.opts): x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -1146,11 +1146,11 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs tile_size = largest_tile_size_available(self.width, self.height) with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - devices.torch_gc() samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x if not self.enable_hr: return samples + devices.torch_gc() if self.latent_scale_mode is None: with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): @@ -1536,7 +1536,6 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs tile_size = largest_tile_size_available(self.width, self.height) with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - devices.torch_gc() samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: From 4f2a4a361511ca3b8cdd7373f6c7d723583e8fdb Mon Sep 17 00:00:00 2001 From: storyicon Date: Fri, 17 Nov 2023 09:48:18 +0000 Subject: [PATCH 0280/1174] feat: fix randn found element of type float at pos 2 Signed-off-by: storyicon --- modules/rng.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/modules/rng.py b/modules/rng.py index 9e8ba2ee9d7..8934d39bf9a 100644 --- a/modules/rng.py +++ b/modules/rng.py @@ -110,7 +110,7 @@ def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resiz self.is_first = True def first(self): - noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) + noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8)) xs = [] From bde439ef67776be126d6a8c569a23d54dbc3e707 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sun, 19 Nov 2023 00:58:47 -0600 Subject: [PATCH 0281/1174] use metadata.ini for meta filename --- modules/extensions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/extensions.py b/modules/extensions.py index 5536db3eab1..f3988d02edf 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -43,13 +43,13 @@ def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=No @functools.cached_property def metadata(self): - if os.path.isfile(os.path.join(self.path, "sd_webui_metadata.ini")): + if os.path.isfile(os.path.join(self.path, "metadata.ini")): try: config = configparser.ConfigParser() - config.read(os.path.join(self.path, "sd_webui_metadata.ini")) + config.read(os.path.join(self.path, "metadata.ini")) return config except Exception: - errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", + errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.", exc_info=True) return None @@ -177,14 +177,14 @@ def list_extensions(): canonical_name = extension_dirname requires = None - if os.path.isfile(os.path.join(path, "sd_webui_metadata.ini")): + if os.path.isfile(os.path.join(path, "metadata.ini")): try: config = configparser.ConfigParser() - config.read(os.path.join(path, "sd_webui_metadata.ini")) + config.read(os.path.join(path, "metadata.ini")) canonical_name = config.get("Extension", "Name", fallback=canonical_name) requires = config.get("Extension", "Requires", fallback=None) except Exception: - errors.report(f"Error reading sd_webui_metadata.ini for extension {extension_dirname}. " + errors.report(f"Error reading metadata.ini for extension {extension_dirname}. 
" f"Will load regardless.", exc_info=True) canonical_name = canonical_name.lower().strip() From 598da5cd4928618b166886d3485ce30ce3a43490 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:50:06 +0800 Subject: [PATCH 0282/1174] Use options instead of cmd_args --- modules/cmd_args.py | 2 -- modules/devices.py | 25 +++++++++------- modules/initialize_util.py | 1 + modules/sd_models.py | 61 ++++++++++++++++++++------------------ modules/shared_options.py | 1 + scripts/xyz_grid.py | 1 + 6 files changed, 49 insertions(+), 42 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 088d5deaac1..a9fb9bfa3ef 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,5 +118,3 @@ parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) -parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) -parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/devices.py b/modules/devices.py index d7c905c283a..03e7bdb7c2e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool: if device_id is None: device_id = get_cuda_device_id() return ( - torch.cuda.get_device_capability(device_id) == (7, 5) + torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") ) def get_cuda_device_id(): return ( - int(shared.cmd_opts.device_id) - if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0 ) or torch.cuda.current_device() @@ -116,16 +116,19 @@ def cond_cast_float(input): torch.nn.LayerNorm, ] + +def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + + @contextlib.contextmanager def manual_autocast(): - def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 2e9b6d895f4..1b11ead61fd 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -177,6 +177,7 @@ def configure_opts_onchange(): shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) 
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange") diff --git a/modules/sd_models.py b/modules/sd_models.py index a6c8b2fa7db..eb4914347b9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -339,10 +339,28 @@ def __exit__(self, exc_type, exc_value, exc_traceback): SkipWritingToConfig.skip = self.previous +def check_fp8(model): + if model is None: + return None + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.opts.fp8_storage == "Enable": + enable_fp8 = True + elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL": + enable_fp8 = True + else: + enable_fp8 = False + return enable_fp8 + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") + if not check_fp8(model) and devices.fp8: + # prevent model to load state dict in fp8 + model.half() + if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title @@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if devices.get_optimal_device_name() == "mps": - enable_fp8 = False - elif shared.cmd_opts.opt_unet_fp8_storage: - enable_fp8 = True - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - enable_fp8 = True - else: - enable_fp8 = False - - if enable_fp8: + if check_fp8(model): devices.fp8 = True - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): + first_stage = model.first_stage_model + model.first_stage_model = None + for module in model.modules(): + if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) - - if devices.device == devices.cpu: - for module in model.model.diffusion_model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) - else: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + model.first_stage_model = first_stage timer.record("apply fp8") else: devices.fp8 = False @@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): return None -def reload_model_weights(sd_model=None, info=None): +def reload_model_weights(sd_model=None, info=None, forced_reload=False): checkpoint_info = info or select_checkpoint() timer = Timer() @@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None): current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if check_fp8(sd_model) != devices.fp8: + # load from state dict again to prevent extra numerical errors + forced_reload = True + elif sd_model.sd_model_checkpoint == checkpoint_info.filename: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) - if sd_model is not None and 
sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: return sd_model if sd_model is not None: diff --git a/modules/shared_options.py b/modules/shared_options.py index f1003f2187b..d27f35e96fc 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -200,6 +200,7 @@ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), + "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc43d..b2250c04dc0 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -270,6 +270,7 @@ def __init__(self, *args, **kwargs): AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')), AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]), + AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]), ] From 890181e1d456b613bf60f6e8378dc68b39011af9 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:54:39 +0800 Subject: [PATCH 0283/1174] Update the xformers/torch versions --- modules/errors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index 8c339464d46..a3498c11f4b 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -93,8 +93,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.0.0" - expected_xformers_version = "0.0.20" + expected_torch_version = "2.1.0" + expected_xformers_version = "0.0.22.post7" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): From f383af2729ec2d1969200218577ab19dd78f7d48 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:56:23 +0800 Subject: [PATCH 0284/1174] update xformers/torch versions --- modules/launch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 636da679ecc..c225bbc180c 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -312,7 +312,7 @@ def prepare_environment(): torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") 
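The fp8 path introduced above has two parts: a policy check (never on MPS; always for "Enable"; SDXL-only for "Enable for SDXL") and a storage pass that casts Linear/Conv weights to float8_e4m3fn while leaving the first-stage VAE untouched. A rough sketch of both on a toy model; it requires PyTorch >= 2.1 for the float8 dtype, and the function and model below are illustrative stand-ins, not the repository's API:

```
import torch
import torch.nn as nn


def should_use_fp8(setting: str, is_sdxl: bool, device: str) -> bool:
    # mirrors the option logic: never on MPS; "Enable" always; "Enable for SDXL" only for SDXL
    if device == "mps":
        return False
    if setting == "Enable":
        return True
    return setting == "Enable for SDXL" and is_sdxl


# toy container just to demonstrate the storage cast; its forward is never run here
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 8))

if should_use_fp8("Enable", is_sdxl=False, device="cuda"):
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.to(torch.float8_e4m3fn)  # storage only; compute must upcast again

print(model[0].weight.dtype)  # torch.float8_e4m3fn
```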
 requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
 
-    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7')
     clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

From 043d2edcf6a543f236f1f3cb70ac72e7b3b357b6 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sun, 19 Nov 2023 15:56:31 +0800
Subject: [PATCH 0285/1174] Better naming

---
 modules/devices.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index 03e7bdb7c2e..c19a7f402f1 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs):
 
 @contextlib.contextmanager
-def manual_autocast():
+def manual_cast():
     for module_type in patch_module_list:
         org_forward = module_type.forward
         module_type.forward = manual_cast_forward
@@ -148,10 +148,10 @@ def autocast(disable=False):
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
     if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
-        return manual_autocast()
+        return manual_cast()
 
     if has_mps() and shared.cmd_opts.precision != "full":
-        return manual_autocast()
+        return manual_cast()
 
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()

From dea5e43c8359b663d5599efc99278c258747db61 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 18 Nov 2023 04:23:03 +0900
Subject: [PATCH 0286/1174] Option to show batch img2img results in UI

shared.opts.img2img_batch_show_results_limit limits the number of images
returned to the UI for batch img2img; the default limit is 32.
0: no images are shown; -1: unlimited, all images are shown.

---
 modules/img2img.py        | 24 ++++++++++++++++++++----
 modules/shared_options.py |  1 +
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/modules/img2img.py b/modules/img2img.py
index 52cb577a646..c583290a0ea 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -44,6 +44,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
     steps = p.steps
     override_settings = p.override_settings
     sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
+    batch_results = None
+    discard_further_results = False
     for i, image in enumerate(images):
         state.job = f"{i+1} out of {len(images)}"
         if state.skipped:
@@ -127,7 +129,21 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
 
         if proc is None:
             p.override_settings.pop('save_images_replace_action', None)
-            process_images(p)
+            proc = process_images(p)
+
+        if not discard_further_results and proc:
+            if batch_results:
+                batch_results.images.extend(proc.images)
+                batch_results.infotexts.extend(proc.infotexts)
+            else:
+                batch_results = proc
+
+            if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
+                discard_further_results = True
+                batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
+                batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
+
+    return batch_results
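The accumulation logic above collects each image's Processed output until the configured cap is hit, then truncates the collected lists and stops extending them (processing itself continues, so files are still saved). A toy sketch of the accumulate-then-truncate pattern, with Processed reduced to a plain list for illustration:

```
def collect_with_limit(batches, limit):
    """limit semantics: 0 collects nothing, -1 collects everything, N > 0 collects at most N."""
    results = []
    discard_further = False
    for batch in batches:
        if discard_further:
            continue  # keep iterating (files would still be saved), just stop collecting
        results.extend(batch)
        if 0 <= limit < len(results):
            discard_further = True
            results = results[:limit]
    return results


print(collect_with_limit([[1, 2], [3, 4], [5]], limit=3))   # [1, 2, 3]
print(collect_with_limit([[1, 2], [3, 4]], limit=0))        # []
print(collect_with_limit([[1, 2], [3, 4]], limit=-1))       # [1, 2, 3, 4]
```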
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): @@ -212,10 +228,10 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s with closing(p): if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" + processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) - - processed = Processed(p, [], p.seed, "") + if processed is None: + processed = Processed(p, [], p.seed, "") else: processed = modules.scripts.scripts_img2img.run(p, *args) if processed is None: diff --git a/modules/shared_options.py b/modules/shared_options.py index d40db5306ee..1ee8c7ad1cc 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -189,6 +189,7 @@ "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(), "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), + "img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. 
Too many images can cause lag'), })) options_templates.update(options_section(('optimizations', "Optimizations"), { From 6d337bf23dae990e7b6717da4d5f2e54f212685c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 20 Nov 2023 01:38:31 +0900 Subject: [PATCH 0287/1174] save sysinfo as .json GitHub now allows uploading of .json files in issues --- modules/launch_utils.py | 2 +- modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 8cdbafa50c8..264ec9ca6a4 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -441,7 +441,7 @@ def dump_sysinfo(): import datetime text = sysinfo.get() - filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" + filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json" with open(filename, "w", encoding="utf8") as file: file.write(text) diff --git a/modules/ui.py b/modules/ui.py index ba0d8542b60..b82f3c5e8b1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1308,7 +1308,7 @@ def download_sysinfo(attachment=False): from fastapi.responses import PlainTextResponse text = sysinfo.get() - filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" + filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json" return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'}) From b2e039d07bed76350120ff448964c907a3b5e4a3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 20 Nov 2023 14:05:32 +0800 Subject: [PATCH 0288/1174] Update webui-macos-env.sh --- webui-macos-env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 24bc5c42615..db7e8b1a05b 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -11,7 +11,7 @@ fi export install_dir="$HOME" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" -export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" +export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0" export PYTORCH_ENABLE_MPS_FALLBACK=1 #################################################################### From 9b471436b2226458a767077707ea102e331b5d78 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 20 Nov 2023 14:47:09 +0300 Subject: [PATCH 0289/1174] rework extensions metadata: use custom sorter that doesn't mess the order as much and ignores cyclic errors, use classes with named fields instead of dictionaries, eliminate some duplicated code --- modules/extensions.py | 132 +++++++++++++++++---------------- modules/scripts.py | 169 +++++++++++++++++++----------------------- 2 files changed, 148 insertions(+), 153 deletions(-) diff --git a/modules/extensions.py b/modules/extensions.py index f3988d02edf..1899cd52975 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import configparser -import functools import os import threading import re @@ -8,7 +9,6 @@ from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 -extensions = [] os.makedirs(extensions_dir, exist_ok=True) @@ -22,13 +22,56 @@ def active(): return [x for x in extensions if x.enabled] +class ExtensionMetadata: + filename = "metadata.ini" + 
config: configparser.ConfigParser + canonical_name: str + requires: list + + def __init__(self, path, canonical_name): + self.config = configparser.ConfigParser() + + filepath = os.path.join(path, self.filename) + if os.path.isfile(filepath): + try: + self.config.read(filepath) + except Exception: + errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) + + self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) + self.canonical_name = canonical_name.lower().strip() + + self.requires = self.get_script_requirements("Requires", "Extension") + + def get_script_requirements(self, field, section, extra_section=None): + """reads a list of requirements from the config; field is the name of the field in the ini file, + like Requires or Before, and section is the name of the [section] in the ini file; additionally, + reads more requirements from [extra_section] if specified.""" + + x = self.config.get(section, field, fallback='') + + if extra_section: + x = x + ', ' + self.config.get(extra_section, field, fallback='') + + return self.parse_list(x.lower()) + + def parse_list(self, text): + """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])""" + + if not text: + return [] + + # both "," and " " are accepted as separator + return [x for x in re.split(r"[,\s]+", text.strip()) if x] + + class Extension: lock = threading.Lock() cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] + metadata: ExtensionMetadata - def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None): + def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None): self.name = name - self.canonical_name = canonical_name or name.lower() self.path = path self.enabled = enabled self.status = '' @@ -40,18 +83,8 @@ def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=No self.branch = None self.remote = None self.have_info_from_repo = False - - @functools.cached_property - def metadata(self): - if os.path.isfile(os.path.join(self.path, "metadata.ini")): - try: - config = configparser.ConfigParser() - config.read(os.path.join(self.path, "metadata.ini")) - return config - except Exception: - errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.", - exc_info=True) - return None + self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower()) + self.canonical_name = metadata.canonical_name def to_dict(self): return {x: getattr(self, x) for x in self.cached_fields} @@ -162,7 +195,7 @@ def list_extensions(): elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") - extension_dependency_map = {} + loaded_extensions = {} # scan through extensions directory and load metadata for dirname in [extensions_builtin_dir, extensions_dir]: @@ -175,55 +208,30 @@ def list_extensions(): continue canonical_name = extension_dirname - requires = None + metadata = ExtensionMetadata(path, canonical_name) - if os.path.isfile(os.path.join(path, "metadata.ini")): - try: - config = configparser.ConfigParser() - config.read(os.path.join(path, "metadata.ini")) - canonical_name = config.get("Extension", "Name", fallback=canonical_name) - requires = config.get("Extension", "Requires", fallback=None) - except Exception: - errors.report(f"Error reading metadata.ini for extension {extension_dirname}. 
" - f"Will load regardless.", exc_info=True) + # check for duplicated canonical names + already_loaded_extension = loaded_extensions.get(metadata.canonical_name) + if already_loaded_extension is not None: + errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False) + continue - canonical_name = canonical_name.lower().strip() + is_builtin = dirname == extensions_builtin_dir + extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata) + extensions.append(extension) + loaded_extensions[canonical_name] = extension - # check for duplicated canonical names - if canonical_name in extension_dependency_map: - errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions " - f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". " - f"The current loading extension will be discarded.", exc_info=False) + # check for requirements + for extension in extensions: + for req in extension.metadata.requires: + required_extension = loaded_extensions.get(req) + if required_extension is None: + errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False) continue - # both "," and " " are accepted as separator - requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else [] + if not extension.enabled: + errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False) + continue - extension_dependency_map[canonical_name] = { - "dirname": extension_dirname, - "path": path, - "requires": requires, - } - # check for requirements - for (_, extension_data) in extension_dependency_map.items(): - dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires'] - requirement_met = True - for req in requires: - if req not in extension_dependency_map: - errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. " - f"The current loading extension will be discarded.", exc_info=False) - requirement_met = False - break - dep_dirname = extension_dependency_map[req]['dirname'] - if dep_dirname in shared.opts.disabled_extensions: - errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. " - f"The current loading extension will be discarded.", exc_info=False) - requirement_met = False - break - - is_builtin = dirname == extensions_builtin_dir - extension = Extension(name=dirname, path=path, - enabled=dirname not in shared.opts.disabled_extensions and requirement_met, - is_builtin=is_builtin) - extensions.append(extension) +extensions: list[Extension] = [] diff --git a/modules/scripts.py b/modules/scripts.py index b1f4504a5d3..b0689a23de1 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -2,7 +2,6 @@ import re import sys import inspect -from graphlib import TopologicalSorter, CycleError from collections import namedtuple from dataclasses import dataclass @@ -312,27 +311,57 @@ def basedir(): postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) +def topological_sort(dependencies): + """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies. 
+    Ignores errors relating to missing dependencies or circular dependencies
+    """
+
+    visited = {}
+    result = []
+
+    def inner(name):
+        visited[name] = True
+
+        for dep in dependencies.get(name, []):
+            if dep in dependencies and dep not in visited:
+                inner(dep)
+
+        result.append(name)
+
+    for depname in dependencies:
+        if depname not in visited:
+            inner(depname)
+
+    return result
+
+
+@dataclass
+class ScriptWithDependencies:
+    script_canonical_name: str
+    file: ScriptFile
+    requires: list
+    load_before: list
+    load_after: list
+
 
 def list_scripts(scriptdirname, extension, *, include_extensions=True):
-    scripts_list = []
-    script_dependency_map = {}
+    scripts = {}
 
-    # build script dependency map
+    loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
+    loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}
+
+    # build script dependency map
     root_script_basedir = os.path.join(paths.script_path, scriptdirname)
     if os.path.exists(root_script_basedir):
         for filename in sorted(os.listdir(root_script_basedir)):
             if not os.path.isfile(os.path.join(root_script_basedir, filename)):
                 continue
 
-            script_dependency_map[filename] = {
-                "extension": None,
-                "extension_dirname": None,
-                "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
-                "requires": [],
-                "load_before": [],
-                "load_after": [],
-            }
+            if os.path.splitext(filename)[1].lower() != extension:
+                continue
+
+            script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
+            scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])
 
     if include_extensions:
         for ext in extensions.active():
             for extension_script in ext.list_files(scriptdirname, extension):
                 if not os.path.isfile(extension_script.path):
                     continue
 
-                script_canonical_name = ext.canonical_name + "/" + extension_script.filename
-                if ext.is_builtin:
-                    script_canonical_name = "builtin/" + script_canonical_name
+                script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
                 relative_path = scriptdirname + "/" + extension_script.filename
 
-                requires = ''
-                load_before = ''
-                load_after = ''
-
-                if ext.metadata is not None:
-                    requires = ext.metadata.get(relative_path, "Requires", fallback='')
-                    load_before = ext.metadata.get(relative_path, "Before", fallback='')
-                    load_after = ext.metadata.get(relative_path, "After", fallback='')
-
-                    # propagate directory level metadata
-                    requires = requires + ',' + ext.metadata.get(scriptdirname, "Requires", fallback='')
-                    load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='')
-                    load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='')
-
-                requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
-                load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else []
-                load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else []
-
-                script_dependency_map[script_canonical_name] = {
-                    "extension": ext.canonical_name,
-                    "extension_dirname": ext.name,
-                    "script_file": extension_script,
-                    "requires": requires,
-                    "load_before": load_before,
-                    "load_after": load_after,
-                }
+                script = ScriptWithDependencies(
+                    script_canonical_name=script_canonical_name,
+                    file=extension_script,
+                    requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
+
load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname), + load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname), + ) - # resolve dependencies + scripts[script_canonical_name] = script + loaded_extensions_scripts[ext.canonical_name].append(script) - loaded_extensions = set() - for ext in extensions.active(): - loaded_extensions.add(ext.canonical_name) - - for script_canonical_name, script_data in script_dependency_map.items(): + for script_canonical_name, script in scripts.items(): # load before requires inverse dependency # in this case, append the script name into the load_after list of the specified script - for load_before_script in script_data['load_before']: + for load_before in script.load_before: # if this requires an individual script to be loaded before - if load_before_script in script_dependency_map: - script_dependency_map[load_before_script]['load_after'].append(script_canonical_name) - elif load_before_script in loaded_extensions: - for _, script_data2 in script_dependency_map.items(): - if script_data2['extension'] == load_before_script: - script_data2['load_after'].append(script_canonical_name) - break - - # resolve extension name in load_after lists - for load_after_script in list(script_data['load_after']): - if load_after_script not in script_dependency_map and load_after_script in loaded_extensions: - script_data['load_after'].remove(load_after_script) - for script_canonical_name2, script_data2 in script_dependency_map.items(): - if script_data2['extension'] == load_after_script: - script_data['load_after'].append(script_canonical_name2) - break - - # build the DAG - sorter = TopologicalSorter() - for script_canonical_name, script_data in script_dependency_map.items(): - requirement_met = True - for required_script in script_data['requires']: - # if this requires an individual script to be loaded - if required_script not in script_dependency_map and required_script not in loaded_extensions: - errors.report(f"Script \"{script_canonical_name}\" " - f"requires \"{required_script}\" to " - f"be loaded, but it is not. Skipping.", - exc_info=False) - requirement_met = False - break - if not requirement_met: - continue + other_script = scripts.get(load_before) + if other_script: + other_script.load_after.append(script_canonical_name) - sorter.add(script_canonical_name, *script_data['load_after']) + # if this requires an extension + other_extension_scripts = loaded_extensions_scripts.get(load_before) + if other_extension_scripts: + for other_script in other_extension_scripts: + other_script.load_after.append(script_canonical_name) - # sort the scripts - try: - ordered_script = sorter.static_order() - except CycleError: - errors.report("Cycle detected in script dependencies. 
Scripts will load in ascending order.", exc_info=True) - ordered_script = script_dependency_map.keys() + # if After mentions an extension, remove it and instead add all of its scripts + for load_after in list(script.load_after): + if load_after not in scripts and load_after in loaded_extensions_scripts: + script.load_after.remove(load_after) + + for other_script in loaded_extensions_scripts.get(load_after, []): + script.load_after.append(other_script.script_canonical_name) + + dependencies = {} + + for script_canonical_name, script in scripts.items(): + for required_script in script.requires: + if required_script not in scripts and required_script not in loaded_extensions: + errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False) - for script_canonical_name in ordered_script: - script_data = script_dependency_map[script_canonical_name] - scripts_list.append(script_data['script_file']) + dependencies[script_canonical_name] = script.load_after - scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] + ordered_scripts = topological_sort(dependencies) + scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts] return scripts_list From 314ae1535ea172fcdb0f5b3b2eecc5d4ce9112b5 Mon Sep 17 00:00:00 2001 From: Tom Haelbich Date: Mon, 20 Nov 2023 16:19:54 +0100 Subject: [PATCH 0290/1174] added option for default behavior of dir buttons --- modules/shared_options.py | 1 + modules/ui_extra_networks.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 00b273faa54..1d2dca797ce 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -224,6 +224,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."), + "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index bd673285680..27a37295fc5 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -138,8 +138,13 @@ def create_html(self, tabname): continue subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") - if not subdir.startswith("/"): - subdir = "/" + subdir + + if shared.opts.extra_networks_dir_button_function: + if not subdir.startswith("/"): + subdir = "/" + subdir + else: + while subdir.startswith("/"): + subdir = subdir[1:] is_empty = len(os.listdir(x)) == 0 if not is_empty and not subdir.endswith("/"): From 58c19545c83fa6925c9ce2216ee64964eb5129ce Mon Sep 17 00:00:00 2001 From: hidenorly Date: Tue, 21 Nov 2023 01:13:53 +0900 Subject: 
From 58c19545c83fa6925c9ce2216ee64964eb5129ce Mon Sep 17 00:00:00 2001 From: hidenorly Date: Tue, 21 Nov 2023 01:13:53 +0900 Subject: [PATCH 0291/1174] Add FP32 fallback support on sd_vae_approx This retries interpolate with FP32 if the half-precision execution failed. The background is that on some environments, such as Mx-chip macOS devices, we get an error as follows: ``` "torch/nn/functional.py", line 3931, in interpolate return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half' ``` In this case, ```--no-half``` doesn't help. Therefore this commit adds an FP32 fallback execution to solve it. Note that the submodule may require additional modifications. The following is an example modification in another submodule. ```repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py class Upsample(nn.Module): ..snip.. def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: x = F.interpolate( x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" ) else: try: x = F.interpolate(x, scale_factor=2, mode="nearest") except: x = F.interpolate(x.to(th.float32), scale_factor=2, mode="nearest").to(x.dtype) if self.use_conv: x = self.conv(x) return x ..snip.. ``` It applies the same FP32 fallback execution as sd_vae_approx.py. --- modules/sd_vae_approx.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 3965e223e6f..8370493f932 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -21,7 +21,13 @@ def __init__(self): def forward(self, x): extra = 11 - x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) + try: + x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) + except RuntimeError as e: + if "not implemented for" in str(e) and "Half" in str(e): + x = nn.functional.interpolate(x.to(torch.float32), (x.shape[2] * 2, x.shape[3] * 2)).to(x.dtype) + else: + print(f"An unexpected RuntimeError occurred: {str(e)}") x = nn.functional.pad(x, (extra, extra, extra, extra)) for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
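The same fallback pattern generalizes beyond this one call site. A hedged sketch of a reusable wrapper built on the idea in the patch above (not part of the patch; the function name is invented): attempt the op in the tensor's own dtype, and only on the specific "not implemented for ... 'Half'" RuntimeError retry in float32 and cast back:

```
import torch
import torch.nn.functional as F

def interpolate_with_fp32_fallback(x: torch.Tensor, **kwargs) -> torch.Tensor:
    try:
        return F.interpolate(x, **kwargs)
    except RuntimeError as e:
        # only the known FP16 gap is retried; anything else still raises
        if "not implemented for" in str(e) and "Half" in str(e):
            return F.interpolate(x.to(torch.float32), **kwargs).to(x.dtype)
        raise

# usage: interpolate_with_fp32_fallback(latent, scale_factor=2, mode="nearest")
```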
From 8aa51f682c17d85f4585b9471860224568d25e95 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 21 Nov 2023 08:32:00 +0300 Subject: [PATCH 0292/1174] fix [Bug]: (Dev Branch) Placing "Dimensions" first in "ui_reorder_list" prevents start #14047 --- modules/ui.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/modules/ui.py b/modules/ui.py index b82f3c5e8b1..08e0ad775c6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -635,12 +635,6 @@ def copy_image(img): scale_by.release(**on_change_args) button_update_resize_to.click(**on_change_args) - # the code below is meant to update the resolution label after the image in the image selection UI has changed. - # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. - # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. - for component in [init_img, sketch]: - component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) - tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) @@ -701,6 +695,12 @@ def copy_image(img): if category not in {"accordions"}: scripts.scripts_img2img.setup_ui_for_section(category) + # the code below is meant to update the resolution label after the image in the image selection UI has changed. + # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests. + # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs. + for component in [init_img, sketch]: + component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False) + def select_img2img_tab(tab): return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
From 370a77f8e78e65a8a1339289d684cb43df142f70 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 21 Nov 2023 19:59:34 +0800 Subject: [PATCH 0293/1174] Option for using fp16 weight when applying lora --- extensions-builtin/Lora/networks.py | 16 ++++++++++++---- modules/initialize_util.py | 1 + modules/sd_models.py | 14 +++++++++++--- modules/shared_options.py | 1 + 4 files changed, 25 insertions(+), 7 deletions(-)
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 0170dbfbd23..d22ed843868 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -388,18 +388,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if module is not None and hasattr(self, 'weight'): try: with torch.no_grad(): - updown, ex_bias = module.calc_updown(self.weight) + if getattr(self, 'fp16_weight', None) is None: + weight = self.weight + bias = self.bias + else: + weight = self.fp16_weight.clone().to(self.weight.device) + bias = getattr(self, 'fp16_bias', None) + if bias is not None: + bias = bias.clone().to(self.bias.device) + updown, ex_bias = module.calc_updown(weight) - if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + if len(weight.shape) == 4 and weight.shape[1] == 9: # inpainting model.
zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) + self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) + self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 1b11ead61fd..7fb1d8d5093 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -178,6 +178,7 @@ def configure_opts_onchange(): shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange")
diff --git a/modules/sd_models.py b/modules/sd_models.py index eb4914347b9..0a7777f1c82 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + for module in model.modules(): + if hasattr(module, 'fp16_weight'): + del module.fp16_weight + if hasattr(module, 'fp16_bias'): + del module.fp16_bias + if check_fp8(model): devices.fp8 = True first_stage = model.first_stage_model model.first_stage_model = None for module in model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + if shared.opts.cache_fp16_weight: + module.fp16_weight = module.weight.clone().half() + if module.bias is not None: + module.fp16_bias = module.bias.clone().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8")
diff --git a/modules/shared_options.py b/modules/shared_options.py index d27f35e96fc..eaa9f1352c7 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -201,6 +201,7 @@ "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), + "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."), })) options_templates.update(options_section(('compatibility', "Compatibility"), {
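Why cache an FP16 master copy at all: float8_e4m3fn keeps only 3 mantissa bits, so applying a small LoRA delta on top of weights that were already rounded to FP8 bakes that rounding error into the result. A toy illustration of the error the cache avoids (assumes PyTorch >= 2.1 for the float8 dtype; the numbers are arbitrary):

```
import torch

w16 = torch.randn(1000, dtype=torch.float16)        # fp16 master weights
w8 = w16.to(torch.float8_e4m3fn)                    # what FP8 storage keeps

delta = 0.01 * torch.randn(1000, dtype=torch.float16)  # a small LoRA update

updated_from_fp8 = w8.to(torch.float16) + delta     # update on rounded weights
updated_from_cache = w16 + delta                    # update on the cached fp16 copy

print((updated_from_fp8 - updated_from_cache).abs().max())  # inherited rounding error
```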
From f5d719d1f1baa775d838aa75d9af1971bcc78e8f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 22 Nov 2023 01:45:56 +0800 Subject: [PATCH 0294/1174] Add forced reload for fp16 cache --- modules/initialize_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 7fb1d8d5093..b6767138dd5 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -178,7 +178,7 @@ def configure_opts_onchange(): shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) - shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False) startup_timer.record("opts onchange")
From 8fe1e195228162a4510925de05015f361efa1087 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 22 Nov 2023 18:01:34 +0200 Subject: [PATCH 0295/1174] Update ruff to 0.1.6 --- .github/workflows/on_pull_request.yaml | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml index 78e608ee945..9e44c806ab3 100644 --- a/.github/workflows/on_pull_request.yaml +++ b/.github/workflows/on_pull_request.yaml @@ -20,7 +20,7 @@ jobs: # not to have GHA download an (at the time of writing) 4 GB cache # of PyTorch and other dependencies. - name: Install Ruff - run: pip install ruff==0.0.272 + run: pip install ruff==0.1.6 - name: Run Ruff run: ruff . lint-js:
diff --git a/pyproject.toml b/pyproject.toml index 80541a8f353..d03036e7d05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ exclude = [ ignore = [ "E501", # Line too long + "E721", # Do not compare types, use `isinstance` "E731", # Do not assign a `lambda` expression, use a `def` "I001", # Import block is un-sorted or un-formatted
From 066afda2f6f650fe108d285a239d08d59d92590d Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 22 Nov 2023 18:02:39 +0200 Subject: [PATCH 0296/1174] Simplify restart_sampler (suggested by ruff) --- modules/sd_samplers_extra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/sd_samplers_extra.py b/modules/sd_samplers_extra.py index 1b981ca80c3..72fd0aa5e60 100644 --- a/modules/sd_samplers_extra.py +++ b/modules/sd_samplers_extra.py @@ -60,7 +60,7 @@ def heun_step(x, old_sigma, new_sigma, second_order=True): sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] while restart_times > 0: restart_times -= 1 - step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])]) + step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:])) last_sigma = None for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
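The simplification above in isolation: `zip(xs[:-1], xs[1:])` already yields consecutive `(old, new)` pairs, so the comprehension that rebuilt the same tuples was redundant. A standalone check:

```
sigmas = [10.0, 5.0, 2.0, 1.0]
pairs = list(zip(sigmas[:-1], sigmas[1:]))
assert pairs == [(10.0, 5.0), (5.0, 2.0), (2.0, 1.0)]
```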
From ac2a981c4f30d77cdb674948fe0e2aa7264a93e1 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Wed, 22 Nov 2023 22:40:24 -0600 Subject: [PATCH 0297/1174] use extension name to determine whether an extension is installed in the index --- modules/ui_extensions.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index c0a73b57380..b670888111a 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -335,6 +335,11 @@ def normalize_git_url(url): return url +def get_extension_dirname_from_url(url): + *parts, last_part = url.split('/') + return normalize_git_url(last_part) + + def install_extension_from_url(dirname, url, branch_name=None): check_access() @@ -346,10 +351,7 @@ def install_extension_from_url(dirname, url, branch_name=None): assert url, 'No URL specified' if dirname is None or dirname == "": - *parts, last_part = url.split('/') - last_part = normalize_git_url(last_part) - - dirname = last_part + dirname = get_extension_dirname_from_url(url) target_dir = os.path.join(extensions.extensions_dir, dirname) assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' @@ -449,7 +451,7 @@ def get_date(info: dict, key): def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): extlist = available_extensions["extensions"] - installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} + installed_extensions = {extension.name for extension in extensions.extensions} tags = available_extensions.get("tags", {}) tags_to_hide = set(hide_tags) @@ -482,7 +484,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" if url is None: continue - existing = installed_extension_urls.get(normalize_git_url(url), None) + existing = get_extension_dirname_from_url(url) in installed_extensions extension_tags = extension_tags + ["installed"] if existing else extension_tags if any(x for x in extension_tags if x in tags_to_hide):
From 86b99b1e98fcdd6e7e5f6017071944364e01e6ad Mon Sep 17 00:00:00 2001 From: Jabasukuriputo Wang Date: Fri, 24 Nov 2023 11:28:54 -0600 Subject:
[PATCH 0298/1174] Move exception_records related methods to errors.py --- modules/errors.py | 18 ++++++++++++++++-- modules/sysinfo.py | 17 +---------------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index 192cd8ffd62..ac9f1ee5e1f 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -6,6 +6,21 @@ exception_records = [] +def format_traceback(tb): + return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)] + + +def format_exception(e, tb): + return {"exception": str(e), "traceback": format_traceback(tb)} + + +def get_exceptions(): + try: + return list(reversed(exception_records)) + except Exception as e: + return str(e) + + def record_exception(): _, e, tb = sys.exc_info() if e is None: @@ -14,8 +29,7 @@ def record_exception(): if exception_records and exception_records[-1] == e: return - from modules import sysinfo - exception_records.append(sysinfo.format_exception(e, tb)) + exception_records.append(format_exception(e, tb)) if len(exception_records) > 5: exception_records.pop(0) diff --git a/modules/sysinfo.py b/modules/sysinfo.py index 7d906e1febd..226b204d96e 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -85,7 +85,7 @@ def get_dict(): "Checksum": checksum_token, "Commandline": sys.argv, "Torch env info": get_torch_sysinfo(), - "Exceptions": get_exceptions(), + "Exceptions": errors.get_exceptions(), "CPU": { "model": platform.processor(), "count logical": psutil.cpu_count(logical=True), @@ -105,21 +105,6 @@ def get_dict(): return res -def format_traceback(tb): - return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)] - - -def format_exception(e, tb): - return {"exception": str(e), "traceback": format_traceback(tb)} - - -def get_exceptions(): - try: - return list(reversed(errors.exception_records)) - except Exception as e: - return str(e) - - def get_environment(): return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist} From 5cedc8f9b2b51f392e7c8f5e29286466e3bee8d6 Mon Sep 17 00:00:00 2001 From: Jabasukuriputo Wang Date: Fri, 24 Nov 2023 11:30:30 -0600 Subject: [PATCH 0299/1174] remove traceback in sysinfo --- modules/sysinfo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/sysinfo.py b/modules/sysinfo.py index 226b204d96e..1d058950ae6 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -1,7 +1,6 @@ import json import os import sys -import traceback import platform import hashlib From 40ac134c553ac824d4a96666bba14d550300daa5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 25 Nov 2023 12:35:09 +0800 Subject: [PATCH 0300/1174] Fix pre-fp8 --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a7777f1c82..90437c87ac1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -357,7 +357,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - if not check_fp8(model) and devices.fp8: + if devices.fp8: # prevent model to load state dict in fp8 model.half() From 3a9bf4ac10d99feb81b0e637417a108d3fa5ac06 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 08:29:12 +0300 Subject: [PATCH 0301/1174] move file --- {modules => extensions-builtin/hypertile}/hypertile.py | 0 1 file changed, 0 
insertions(+), 0 deletions(-) rename {modules => extensions-builtin/hypertile}/hypertile.py (100%) diff --git a/modules/hypertile.py b/extensions-builtin/hypertile/hypertile.py similarity index 100% rename from modules/hypertile.py rename to extensions-builtin/hypertile/hypertile.py From d2e0c1ca132f4f0d98b77397a9f353d4ad8e7c4b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 10:51:45 +0300 Subject: [PATCH 0302/1174] rework hypertile into a built-in extension --- README.md | 1 + extensions-builtin/hypertile/hypertile.py | 221 ++++++++---------- .../hypertile/scripts/hypertile_script.py | 73 ++++++ modules/processing.py | 37 ++- modules/shared_options.py | 8 - 5 files changed, 186 insertions(+), 154 deletions(-) create mode 100644 extensions-builtin/hypertile/scripts/hypertile_script.py diff --git a/README.md b/README.md index 25ba070e7e8..3b3f93adc19 100644 --- a/README.md +++ b/README.md @@ -174,5 +174,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd - LyCORIS - KohakuBlueleaf - Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling +- Hypertile - tfernd - https://github.com/tfernd/HyperTile - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/extensions-builtin/hypertile/hypertile.py b/extensions-builtin/hypertile/hypertile.py index be898fce43f..a40c131186d 100644 --- a/extensions-builtin/hypertile/hypertile.py +++ b/extensions-builtin/hypertile/hypertile.py @@ -1,10 +1,13 @@ """ Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE -Warn : The patch works well only if the input image has a width and height that are multiples of 128 -Author : @tfernd Github : https://github.com/tfernd/HyperTile +Warn: The patch works well only if the input image has a width and height that are multiples of 128 +Original author: @tfernd Github: https://github.com/tfernd/HyperTile """ from __future__ import annotations + +import functools +from dataclasses import dataclass from typing import Callable from typing_extensions import Literal @@ -18,6 +21,19 @@ from einops import rearrange + +@dataclass +class HypertileParams: + depth = 0 + layer_name = "" + tile_size: int = 0 + swap_size: int = 0 + aspect_ratio: float = 1.0 + forward = None + enabled = False + + + # TODO add SD-XL layers DEPTH_LAYERS = { 0: [ @@ -176,6 +192,7 @@ RNG_INSTANCE = random.Random() + def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: """ Returns a random divisor of value that @@ -193,10 +210,13 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: return ns[idx] + def set_hypertile_seed(seed: int) -> None: RNG_INSTANCE.seed(seed) -def largest_tile_size_available(width:int, height:int) -> int: + +@functools.cache +def largest_tile_size_available(width: int, height: int) -> int: """ Calculates the largest tile size available for a given width and height Tile size is always a power of 2 @@ -207,6 +227,7 @@ def largest_tile_size_available(width:int, height:int) -> int: largest_tile_size_available *= 2 return largest_tile_size_available + def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]: """ Finds h and w such that h*w = hw and h/w = aspect_ratio @@ -219,6 +240,7 @@ def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]: closest_pair = pairs[ratios.index(closest_ratio)] # 
closest pair of divisors to aspect_ratio return closest_pair + @cache def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]: """ @@ -240,132 +262,87 @@ def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]: w = int(w_candidate) return h, w -@contextmanager -def split_attention( - layer: nn.Module, - /, - aspect_ratio: float, # width/height - tile_size: int = 128, # 128 for VAE - swap_size: int = 1, # 1 for VAE - *, - disable: bool = False, - max_depth: Literal[0, 1, 2, 3] = 0, # ! Try 0 or 1 - scale_depth: bool = True, # scale the tile-size depending on the depth - is_sdxl: bool = False, # is the model SD-XL -): - # Hijacks AttnBlock from ldm and Attention from diffusers - - if disable: - logging.info(f"Attention for {layer.__class__.__qualname__} not splitted") - yield - return - - latent_tile_size = max(128, tile_size) // 8 - - def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable: - @wraps(forward) - def wrapper(*args, **kwargs): - x = args[0] - - # VAE - if x.ndim == 4: - b, c, h, w = x.shape - - nh = random_divisor(h, latent_tile_size, swap_size) - nw = random_divisor(w, latent_tile_size, swap_size) - - if nh * nw > 1: - x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles - - out = forward(x, *args[1:], **kwargs) - - if nh * nw > 1: - out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw) - - # U-Net - else: - hw: int = x.size(1) - h, w = find_hw_candidates(hw, aspect_ratio) - assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}" - factor = 2**depth if scale_depth else 1 - nh = random_divisor(h, latent_tile_size * factor, swap_size) - nw = random_divisor(w, latent_tile_size * factor, swap_size) +def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable: + + @wraps(params.forward) + def wrapper(*args, **kwargs): + if not params.enabled: + return params.forward(*args, **kwargs) - module._split_sizes_hypertile.append((nh, nw)) # type: ignore + latent_tile_size = max(128, params.tile_size) // 8 + x = args[0] - if nh * nw > 1: - x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + # VAE + if x.ndim == 4: + b, c, h, w = x.shape - out = forward(x, *args[1:], **kwargs) + nh = random_divisor(h, latent_tile_size, params.swap_size) + nw = random_divisor(w, latent_tile_size, params.swap_size) - if nh * nw > 1: - out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) - out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + if nh * nw > 1: + x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles - return out + out = params.forward(x, *args[1:], **kwargs) - return wrapper + if nh * nw > 1: + out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw) - # Handle hijacking the forward method and recovering afterwards - try: - if is_sdxl: - layers = DEPTH_LAYERS_XL + # U-Net else: - layers = DEPTH_LAYERS - for depth in range(max_depth + 1): - for layer_name, module in layer.named_modules(): + hw: int = x.size(1) + h, w = find_hw_candidates(hw, params.aspect_ratio) + assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}" + + factor = 2 ** params.depth if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, params.swap_size) + nw = random_divisor(w, latent_tile_size * 
factor, params.swap_size) + + if nh * nw > 1: + x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + + out = params.forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) + out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + + return out + + return wrapper + + +def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False): + hypertile_layers = getattr(model, "__webui_hypertile_layers", None) + if hypertile_layers is None: + if not enable: + return + + hypertile_layers = {} + layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS + + for depth in range(4): + for layer_name, module in model.named_modules(): if any(layer_name.endswith(try_name) for try_name in layers[depth]): - # print input shape for debugging - logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}") - # hijack - module._original_forward_hypertile = module.forward - module.forward = self_attn_forward(module.forward, depth, layer_name, module) - module._split_sizes_hypertile = [] - yield - finally: - for layer_name, module in layer.named_modules(): - # remove hijack - if hasattr(module, "_original_forward_hypertile"): - if module._split_sizes_hypertile: - logging.debug(f"layer {layer_name} splitted with ({module._split_sizes_hypertile})") - # recover - module.forward = module._original_forward_hypertile - del module._original_forward_hypertile - del module._split_sizes_hypertile - -def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts): - """ - Returns context manager for VAE - """ - enabled = opts.hypertile_split_vae_attn - swap_size = opts.hypertile_swap_size_vae - max_depth = opts.hypertile_max_depth_vae - tile_size_max = opts.hypertile_max_tile_vae - return split_attention( - model, - aspect_ratio=aspect_ratio, - tile_size=min(tile_size, tile_size_max), - swap_size=swap_size, - disable=not enabled, - max_depth=max_depth, - is_sdxl=False, - ) - -def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, is_sdxl:bool): - """ - Returns context manager for U-Net - """ - enabled = opts.hypertile_split_unet_attn - swap_size = opts.hypertile_swap_size_unet - max_depth = opts.hypertile_max_depth_unet - tile_size_max = opts.hypertile_max_tile_unet - return split_attention( - model, - aspect_ratio=aspect_ratio, - tile_size=min(tile_size, tile_size_max), - swap_size=swap_size, - disable=not enabled, - max_depth=max_depth, - is_sdxl=is_sdxl, - ) + params = HypertileParams() + module.__webui_hypertile_params = params + params.forward = module.forward + params.depth = depth + params.layer_name = layer_name + module.forward = self_attn_forward(params) + + hypertile_layers[layer_name] = 1 + + model.__webui_hypertile_layers = hypertile_layers + + aspect_ratio = width / height + tile_size = min(largest_tile_size_available(width, height), tile_size_max) + + for layer_name, module in model.named_modules(): + if layer_name in hypertile_layers: + params = module.__webui_hypertile_params + + params.tile_size = tile_size + params.swap_size = swap_size + params.aspect_ratio = aspect_ratio + params.enabled = enable and params.depth <= max_depth diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py new file mode 100644 index 00000000000..3cc29cd1fa4 --- /dev/null +++ 
b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -0,0 +1,73 @@ +import hypertile +from modules import scripts, script_callbacks, shared + + +class ScriptHypertile(scripts.Script): + name = "Hypertile" + + def title(self): + return self.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def process(self, p, *args): + hypertile.set_hypertile_seed(p.all_seeds[0]) + + configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) + + def before_hr(self, p, *args): + configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet) + + +def configure_hypertile(width, height, enable_unet=True): + hypertile.hypertile_hook_model( + shared.sd_model.first_stage_model, + width, + height, + swap_size=shared.opts.hypertile_swap_size_vae, + max_depth=shared.opts.hypertile_max_depth_vae, + tile_size_max=shared.opts.hypertile_max_tile_vae, + enable=shared.opts.hypertile_enable_vae, + ) + + hypertile.hypertile_hook_model( + shared.sd_model.model, + width, + height, + swap_size=shared.opts.hypertile_swap_size_unet, + max_depth=shared.opts.hypertile_max_depth_unet, + tile_size_max=shared.opts.hypertile_max_tile_unet, + enable=enable_unet, + is_sdxl=shared.sd_model.is_sdxl + ) + + +def on_ui_settings(): + import gradio as gr + + options = { + "hypertile_explanation": shared.OptionHTML(""" + Hypertile optimizes the self-attention layer within U-Net and VAE models, + resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the + benefit. + """), + + "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"), + "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"), + "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), + "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), + "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}), + + "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"), + "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), + "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), + "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}), + } + + for name, opt in options.items(): + opt.section = ('hypertile', "Hypertile") + shared.opts.add_option(name, opt) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/modules/processing.py b/modules/processing.py index 36c2be5e536..ac58ef86940 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -24,7 +24,6 @@ import modules.shared as shared import modules.paths as paths import modules.face_restoration -from modules.hypertile import set_hypertile_seed, largest_tile_size_available, hypertile_context_unet, hypertile_context_vae import modules.images as images import modules.styles import modules.sd_models as sd_models @@ -861,8 
+860,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.comment(comment) p.extra_generation_params.update(model_hijack.extra_generation_params) - set_hypertile_seed(p.seed) - # add batch size + hypertile status to information to reproduce the run + if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" @@ -874,8 +872,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - with hypertile_context_vae(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), opts=shared.opts): - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -1141,25 +1138,23 @@ def init(self, all_prompts, all_seeds, all_subseeds): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - aspect_ratio = self.width / self.height + x = self.rng.next() - tile_size = largest_tile_size_available(self.width, self.height) - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x + if not self.enable_hr: return samples devices.torch_gc() if self.latent_scale_mode is None: - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) else: decoded_samples = None with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) + return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): @@ -1244,18 +1239,15 @@ def save_intermediate(image, index): if self.scripts is not None: self.scripts.before_hr(self) - tile_size = largest_tile_size_available(target_width, target_height) - aspect_ratio = self.width / self.height - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) + + 
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) self.sampler = None devices.torch_gc() - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) + + decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) self.is_hr_pass = False return decoded_samples @@ -1532,11 +1524,8 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs if self.initial_noise_multiplier != 1.0: self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier x *= self.initial_noise_multiplier - aspect_ratio = self.width / self.height - tile_size = largest_tile_size_available(self.width, self.height) - with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) + + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask
diff --git a/modules/shared_options.py b/modules/shared_options.py index 28a48906937..d40db5306ee 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -200,14 +200,6 @@ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), - "hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), - "hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), - "hypertile_max_depth_vae" : OptionInfo(3, "Max depth for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_max_depth_unet" : OptionInfo(3, "Max depth for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_max_tile_vae" : OptionInfo(128, "Max tile size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_max_tile_unet" : OptionInfo(256, "Max tile size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_swap_size_unet": OptionInfo(3, "Swap size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_swap_size_vae": OptionInfo(3, "Swap size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), })) options_templates.update(options_section(('compatibility', "Compatibility"), {
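The core trick behind the hypertile rework above is the pair of `rearrange` calls in `self_attn_forward`: fold `nh * nw` spatial tiles into the batch dimension so self-attention runs per tile, then unfold the result. A standalone illustration with toy shapes (not webui code):

```
import torch
from einops import rearrange

b, c, h, w = 1, 4, 64, 64
nh = nw = 2                      # tiles per side (chosen by random_divisor)
x = torch.randn(b, c, h, w)

# split into nh * nw tiles, stacked along the batch axis
tiles = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)
assert tiles.shape == (b * nh * nw, c, h // nh, w // nw)

# the inverse pattern restores the original tensor exactly
restored = rearrange(tiles, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
assert torch.equal(x, restored)
```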
"https://github.com/tfernd/HyperTile"), - "hypertile_max_tile_unet" : OptionInfo(256, "Max tile size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_swap_size_unet": OptionInfo(3, "Swap size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), - "hypertile_swap_size_vae": OptionInfo(3, "Swap size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { From c5a0c59a83c950c64bc44427d3478aaa78c296cf Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 11:36:17 +0300 Subject: [PATCH 0303/1174] do not save HTML explanations from options page to config --- modules/options.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/options.py b/modules/options.py index ab40aff73f0..7703d80ec76 100644 --- a/modules/options.py +++ b/modules/options.py @@ -76,7 +76,7 @@ class Options: def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts): self.data_labels = data_labels - self.data = {k: v.default for k, v in self.data_labels.items()} + self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save} self.restricted_opts = restricted_opts def __setattr__(self, key, value): @@ -210,7 +210,7 @@ def dumpjson(self): def add_option(self, key, info): self.data_labels[key] = info - if key not in self.data: + if key not in self.data and not info.do_not_save: self.data[key] = info.default def reorder(self): From d1750e5eca6fd95db3516928cad18b32e557f56f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 11:37:12 +0300 Subject: [PATCH 0304/1174] fix linter errors --- extensions-builtin/hypertile/hypertile.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/extensions-builtin/hypertile/hypertile.py b/extensions-builtin/hypertile/hypertile.py index a40c131186d..feb02fd2737 100644 --- a/extensions-builtin/hypertile/hypertile.py +++ b/extensions-builtin/hypertile/hypertile.py @@ -9,11 +9,8 @@ import functools from dataclasses import dataclass from typing import Callable -from typing_extensions import Literal -import logging from functools import wraps, cache -from contextlib import contextmanager import math import torch.nn as nn From 29f04149b60bcf6e8e2b41a161d6cc7e8981710f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 12:07:33 +0300 Subject: [PATCH 0305/1174] update torch to 2.1.0 --- modules/errors.py | 4 ++-- modules/launch_utils.py | 6 +++--- webui-macos-env.sh | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index eb234a83811..c534a5d66fe 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -107,8 +107,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.0.0" - expected_xformers_version = "0.0.20" + expected_torch_version = "2.1.0" + expected_xformers_version = "0.0.22.post7" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 264ec9ca6a4..1f2b6c5e8f6 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -308,11 +308,11 @@ def 
requirements_met(requirements_file): def prepare_environment(): - torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 24bc5c42615..db7e8b1a05b 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -11,7 +11,7 @@ fi export install_dir="$HOME" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" -export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" +export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0" export PYTORCH_ENABLE_MPS_FALLBACK=1 #################################################################### From 2a40d3c603448d15e209814366f2d6ab25e52398 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 14:58:47 +0300 Subject: [PATCH 0306/1174] compact prompt layout: preserve scroll when switching between lora tabs --- javascript/extraNetworks.js | 4 ++++ modules/ui_extra_networks.py | 5 ++++- style.css | 12 ++++++++++-- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index a1bf29a8c72..a787372cff3 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -130,6 +130,10 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp } else { promptContainer.insertBefore(prompt, promptContainer.firstChild); } + + if (elem) { + elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt); + } } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index f03e20337c2..f3b23cc9c29 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -370,6 +370,9 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): for page in ui.stored_extra_pages: with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: + with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]): + pass + elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) @@ -400,7 +403,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): allow_prompt = "true" if page.allow_prompt else "false" allow_negative_prompt = "true" if page.allow_negative_prompt else "false" - jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');' + jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", 
' + allow_prompt + ', ' + allow_negative_prompt + ');' tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
diff --git a/style.css b/style.css index 731620226e9..f8b42636d7f 100644 --- a/style.css +++ b/style.css @@ -840,8 +840,16 @@ footer { /* extra networks UI */ -.extra-page .prompt{ - margin: 0 0 0.5em 0; +.extra-page > div.gap{ + gap: 0; +} + +.extra-page-prompts{ + margin-bottom: 0; +} + +.extra-page-prompts.extra-page-prompts-active{ + margin-bottom: 1em; } .extra-network-cards{
From a15dd151ffb4d11556028b34561058bc44930427 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 26 Nov 2023 21:55:50 +0900 Subject: [PATCH 0307/1174] json.dump(ensure_ascii=False) improve json readability --- modules/cache.py | 2 +- modules/options.py | 2 +- modules/ui_extensions.py | 2 +- modules/ui_extra_networks_user_metadata.py | 2 +- modules/ui_loadsave.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/modules/cache.py b/modules/cache.py index ff26a2132d9..2d37e7b99d5 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -32,7 +32,7 @@ def thread_func(): with cache_lock: cache_filename_tmp = cache_filename + "-" with open(cache_filename_tmp, "w", encoding="utf8") as file: - json.dump(cache_data, file, indent=4) + json.dump(cache_data, file, indent=4, ensure_ascii=False) os.replace(cache_filename_tmp, cache_filename)
diff --git a/modules/options.py b/modules/options.py index 7703d80ec76..40cb4799139 100644 --- a/modules/options.py +++ b/modules/options.py @@ -158,7 +158,7 @@ def save(self, filename): assert not cmd_opts.freeze_settings, "saving settings is disabled" with open(filename, "w", encoding="utf8") as file: - json.dump(self.data, file, indent=4) + json.dump(self.data, file, indent=4, ensure_ascii=False) def same_type(self, x, y): if x is None or y is None:
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index c0a73b57380..96dc9db2ce2 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -65,7 +65,7 @@ def save_config_state(name): filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json") print(f"Saving backup of webui/extension state to {filename}.") with open(filename, "w", encoding="utf-8") as f: - json.dump(current_config_state, f, indent=4) + json.dump(current_config_state, f, indent=4, ensure_ascii=False) config_states.list_config_states() new_value = next(iter(config_states.all_config_states.keys()), "Current") new_choices = ["Current"] + list(config_states.all_config_states.keys())
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index bfec140cc73..36a807fcdf9 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -134,7 +134,7 @@ def write_user_metadata(self, name, metadata): basename, ext = os.path.splitext(filename) with open(basename + '.json', "w", encoding="utf8") as file: - json.dump(metadata, file, indent=4) + json.dump(metadata, file, indent=4, ensure_ascii=False) def save_user_metadata(self, name, desc, notes): user_metadata = self.get_user_metadata(name)
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index eb20ff258a4..7826786ccde 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -141,7 +141,7 @@ def read_from_file(self): def write_to_file(self, current_ui_settings): with open(self.filename, "w", encoding="utf8") as file: - json.dump(current_ui_settings, file, indent=4) + json.dump(current_ui_settings, file, indent=4, ensure_ascii=False) def dump_defaults(self): """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
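What `ensure_ascii=False` changes, in two lines: by default the `json` module escapes every non-ASCII character, which makes saved prompts and metadata unreadable; with the flag the text is written out as UTF-8 directly:

```
import json

data = {"prompt": "桜の木の下で"}
print(json.dumps(data))                      # {"prompt": "\u685c\u306e\u6728..."} - escaped, unreadable
print(json.dumps(data, ensure_ascii=False))  # {"prompt": "桜の木の下で"} - plain UTF-8
```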
From f0f100e67b78f686dc73cf3c8cad422e45cc9b8a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 17:56:16 +0300 Subject: [PATCH 0308/1174] add categories to settings --- javascript/settings.js | 25 +++++++++++++ modules/options.py | 75 +++++++++++++++++++++++++++++++++++---- modules/shared_options.py | 49 ++++++++++++++----------- style.css | 9 +++++ 4 files changed, 130 insertions(+), 28 deletions(-)
diff --git a/javascript/settings.js b/javascript/settings.js index 4e79ec003a5..e6009290ab3 100644 --- a/javascript/settings.js +++ b/javascript/settings.js @@ -44,3 +44,28 @@ onUiLoaded(function() { buttonShowAllPages.addEventListener("click", settingsShowAllTabs); }); + + +onOptionsChanged(function() { + if (gradioApp().querySelector('#settings .settings-category')) return; + + var sectionMap = {}; + gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) { + sectionMap[x.textContent.trim()] = x; + }); + + opts._categories.forEach(function(x) { + var section = x[0]; + var category = x[1]; + + var span = document.createElement('SPAN'); + span.textContent = category; + span.className = 'settings-category'; + + var sectionElem = sectionMap[section]; + if (!sectionElem) return; + + sectionElem.parentElement.insertBefore(span, sectionElem); + }); +}); +
diff --git a/modules/options.py b/modules/options.py index 40cb4799139..4fead690cea 100644 --- a/modules/options.py +++ b/modules/options.py @@ -1,5 +1,6 @@ import json import sys +from dataclasses import dataclass import gradio as gr @@ -8,13 +9,14 @@ class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange self.section = section + self.category_id = category_id self.refresh = refresh self.do_not_save = False @@ -63,7 +65,11 @@ def __init__(self, text): def options_section(section_identifier, options_dict): for v in options_dict.values(): - v.section = section_identifier + if len(section_identifier) == 2: + v.section = section_identifier + elif len(section_identifier) == 3: + v.section = section_identifier[0:2] + v.category_id = section_identifier[2] return options_dict @@ -206,6 +212,17 @@ def dumpjson(self): d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()} d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None} d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None} + + item_categories = {} + for item in self.data_labels.values(): + category = categories.mapping.get(item.category_id) + category = "Uncategorized" if category is None else category.label + if category not in item_categories: + item_categories[category] = item.section[1] + + # _categories is a list of pairs: [section, category].
Each section (a setting page) will get a special heading above it with the category as text. + d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]] + return json.dumps(d) def add_option(self, key, info): @@ -214,15 +231,40 @@ def add_option(self, key, info): self.data[key] = info.default def reorder(self): - """reorder settings so that all items related to section always go together""" + """Reorder settings so that: + - all items related to section always go together + - all sections belonging to a category go together + - sections inside a category are ordered alphabetically + - categories are ordered by creation order + + Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling". + + This function also changes items' category_id so that all items belonging to a section have the same category_id. + """ + + category_ids = {} + section_categories = {} - section_ids = {} settings_items = self.data_labels.items() for _, item in settings_items: - if item.section not in section_ids: - section_ids[item.section] = len(section_ids) + if item.section not in section_categories: + section_categories[item.section] = item.category_id + + for _, item in settings_items: + item.category_id = section_categories.get(item.section) + + for category_id in categories.mapping: + if category_id not in category_ids: + category_ids[category_id] = len(category_ids) - self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section])) + def sort_key(x): + item: OptionInfo = x[1] + category_order = category_ids.get(item.category_id, len(category_ids)) + section_order = item.section[1] + + return category_order, section_order + + self.data_labels = dict(sorted(settings_items, key=sort_key)) def cast_value(self, key, value): """casts an arbitrary to the same type as this setting's value with key @@ -245,3 +287,22 @@ def cast_value(self, key, value): value = expected_type(value) return value + + +@dataclass +class OptionsCategory: + id: str + label: str + +class OptionsCategories: + def __init__(self): + self.mapping = {} + + def register_category(self, category_id, label): + if category_id in self.mapping: + return category_id + + self.mapping[category_id] = OptionsCategory(category_id, label) + + +categories = OptionsCategories() diff --git a/modules/shared_options.py b/modules/shared_options.py index 9bcd7914b78..04e68a712ef 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -3,7 +3,7 @@ from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from modules.shared_cmd_options import cmd_opts -from modules.options import options_section, OptionInfo, OptionHTML +from modules.options import options_section, OptionInfo, OptionHTML, categories options_templates = {} hide_dirs = shared.hide_dirs @@ -21,7 +21,14 @@ "outdir_init_images" } -options_templates.update(options_section(('saving-images', "Saving images/grids"), { +categories.register_category("saving", "Saving images") +categories.register_category("sd", "Stable Diffusion") +categories.register_category("ui", "User Interface") +categories.register_category("system", "System") +categories.register_category("postprocessing", "Postprocessing") 
+categories.register_category("training", "Training") + +options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), { "samples_save": OptionInfo(True, "Always save all generated images"), "samples_format": OptionInfo('png', 'File format for images'), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), @@ -67,7 +74,7 @@ "notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"), })) -options_templates.update(options_section(('saving-paths', "Paths for saving"), { +options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), { "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), @@ -79,7 +86,7 @@ "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs), })) -options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { +options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), { "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"), "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"), "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), @@ -87,21 +94,21 @@ "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), })) -options_templates.update(options_section(('upscaling', "Upscaling"), { +options_templates.update(options_section(('upscaling', "Upscaling", "postprocessing"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), })) -options_templates.update(options_section(('face-restoration', "Face restoration"), { +options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), { "face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"), "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"), "face_restoration_unload": OptionInfo(False, 
"Move face restoration model from VRAM into RAM after processing"), })) -options_templates.update(options_section(('system', "System"), { +options_templates.update(options_section(('system', "System", "system"), { "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}), "enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."), "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(), @@ -116,13 +123,13 @@ "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), })) -options_templates.update(options_section(('API', "API"), { +options_templates.update(options_section(('API', "API", "system"), { "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True), "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True), "api_useragent": OptionInfo("", "User agent for requests", restrict_api=True), })) -options_templates.update(options_section(('training', "Training"), { +options_templates.update(options_section(('training', "Training", "training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), @@ -137,7 +144,7 @@ "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), })) -options_templates.update(options_section(('sd', "Stable Diffusion"), { +options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'), "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), @@ -154,14 +161,14 @@ "hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"), })) -options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { +options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), { "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) -options_templates.update(options_section(('vae', "VAE"), { +options_templates.update(options_section(('vae', "VAE", "sd"), { "sd_vae_explanation": OptionHTML(""" VAE is 
a neural network that transforms a standard RGB image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
@@ -176,7 +183,7 @@
     "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
 }))

-options_templates.update(options_section(('img2img', "img2img"), {
+options_templates.update(options_section(('img2img', "img2img", "sd"), {
     "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
     "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
     "img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
@@ -192,7 +199,7 @@
     "img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
 }))

-options_templates.update(options_section(('optimizations', "Optimizations"), {
+options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
     "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
     "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
@@ -203,7 +210,7 @@
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument"),
 }))

-options_templates.update(options_section(('compatibility', "Compatibility"), {
+options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
     "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. 
Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), @@ -228,7 +235,7 @@ "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"), })) -options_templates.update(options_section(('extra_networks', "Extra Networks"), { +options_templates.update(options_section(('extra_networks', "Extra Networks", "sd"), { "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), @@ -245,7 +252,7 @@ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks), })) -options_templates.update(options_section(('ui', "User interface"), { +options_templates.update(options_section(('ui', "User interface", "ui"), { "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the gallery.").needs_reload_ui(), "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), @@ -280,7 +287,7 @@ })) -options_templates.update(options_section(('infotext', "Infotext"), { +options_templates.update(options_section(('infotext', "Infotext", "ui"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), @@ -295,7 +302,7 @@ })) -options_templates.update(options_section(('ui', "Live previews"), { +options_templates.update(options_section(('ui', "Live previews", "ui"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}), @@ -308,7 +315,7 @@ "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), })) -options_templates.update(options_section(('sampler-params', "Sampler parameters"), { +options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(), "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, 
{"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"), "eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"), @@ -330,7 +337,7 @@ 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), })) -options_templates.update(options_section(('postprocessing', "Postprocessing"), { +options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), diff --git a/style.css b/style.css index f8b42636d7f..6e3ca84113e 100644 --- a/style.css +++ b/style.css @@ -462,6 +462,15 @@ div.toprow-compact-tools{ padding: 4px; } +#settings > div.tab-nav .settings-category{ + display: block; + margin: 1em 0 0.25em 0; + font-weight: bold; + text-decoration: underline; + cursor: default; + user-select: none; +} + #settings_result{ height: 1.4em; margin: 0 1.2em; From 1f6844eb7e3a91639b2977d1e0cfbb9bf98baea7 Mon Sep 17 00:00:00 2001 From: Jabasukuriputo Wang Date: Sun, 26 Nov 2023 10:04:39 -0600 Subject: [PATCH 0309/1174] also consider extension url --- modules/ui_extensions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index b670888111a..252e6ff2c23 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -452,6 +452,7 @@ def get_date(info: dict, key): def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): extlist = available_extensions["extensions"] installed_extensions = {extension.name for extension in extensions.extensions} + installed_extension_urls = {normalize_git_url(extension.remote) for extension in extensions.extensions if extension.remote is not None} tags = available_extensions.get("tags", {}) tags_to_hide = set(hide_tags) @@ -484,7 +485,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" if url is None: continue - existing = get_extension_dirname_from_url(url) in installed_extensions + existing = get_extension_dirname_from_url(url) in installed_extensions or normalize_git_url(url) in installed_extension_urls extension_tags = extension_tags + ["installed"] if existing else extension_tags if any(x for x in extension_tags if x in tags_to_hide): From b30cc87b786d32f2385cfecf40a2469ee3a96ab5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 27 Nov 2023 13:15:17 +0900 Subject: [PATCH 0310/1174] add Block component creation callback --- modules/gradio_extensons.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py index e6b6835adcc..7d88dc984bb 100644 --- a/modules/gradio_extensons.py +++ b/modules/gradio_extensons.py @@ -47,10 +47,20 @@ def 
Block_get_config(self):


 def BlockContext_init(self, *args, **kwargs):
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.before_component(self, **kwargs)
+
+    scripts.script_callbacks.before_component_callback(self, **kwargs)
+
     res = original_BlockContext_init(self, *args, **kwargs)

     add_classes_to_gradio_component(self)

+    scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.after_component(self, **kwargs)
+
     return res


From 8a6e4bda21dddef3ab2e70a05d71b587b6c8b04b Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Mon, 27 Nov 2023 14:00:17 +0900
Subject: [PATCH 0311/1174] catch uncaught exceptions in ui creation scripts to
 prevent total webui crash

---
 modules/scripts.py | 54 +++++++++++++++++++++++++---------------------
 1 file changed, 29 insertions(+), 25 deletions(-)

diff --git a/modules/scripts.py b/modules/scripts.py
index b0689a23de1..961d032cea0 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -570,40 +570,44 @@ def create_script_ui(self, script):
         if controls is None:
             return

-        script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
-        api_args = []
+        try:
+            script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+            api_args = []

-        for control in controls:
-            control.custom_script_source = os.path.basename(script.filename)
+            for control in controls:
+                control.custom_script_source = os.path.basename(script.filename)

-            arg_info = api_models.ScriptArg(label=control.label or "")
+                arg_info = api_models.ScriptArg(label=control.label or "")

-            for field in ("value", "minimum", "maximum", "step"):
-                v = getattr(control, field, None)
-                if v is not None:
-                    setattr(arg_info, field, v)
+                for field in ("value", "minimum", "maximum", "step"):
+                    v = getattr(control, field, None)
+                    if v is not None:
+                        setattr(arg_info, field, v)

-            choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string
-            if choices is not None:
-                arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices]
+                choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string
+                if choices is not None:
+                    arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices]

-            api_args.append(arg_info)
+                api_args.append(arg_info)

-        script.api_info = api_models.ScriptInfo(
-            name=script.name,
-            is_img2img=script.is_img2img,
-            is_alwayson=script.alwayson,
-            args=api_args,
-        )
+            script.api_info = api_models.ScriptInfo(
+                name=script.name,
+                is_img2img=script.is_img2img,
+                is_alwayson=script.alwayson,
+                args=api_args,
+            )

-        if script.infotext_fields is not None:
-            self.infotext_fields += script.infotext_fields
+            if script.infotext_fields is not None:
+                self.infotext_fields += script.infotext_fields

-        if script.paste_field_names is not None:
-            self.paste_field_names += script.paste_field_names
+            if script.paste_field_names is not None:
+                self.paste_field_names += script.paste_field_names

-        self.inputs += controls
-        script.args_to = len(self.inputs)
+            self.inputs += controls
+            script.args_to = len(self.inputs)
+
+        except Exception:
+            errors.report(f"Error creating UI for {script.name}: ", exc_info=True)

 def setup_ui_for_section(self, section, scriptlist=None):
     if scriptlist is None:
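As an aside, the guard introduced by the patch above follows a common pattern: build each script's UI inside try/except and report failures instead of propagating them, so one broken extension cannot abort UI creation for everything else. A minimal sketch of that pattern (build_ui_for and create_all_script_uis are hypothetical stand-ins; errors.report is the webui helper used in the patch):

    from modules import errors

    def build_ui_for(script):
        # hypothetical stand-in for the body of create_script_ui
        raise RuntimeError("broken extension UI")

    def create_all_script_uis(scripts_list):
        for script in scripts_list:
            try:
                build_ui_for(script)
            except Exception:
                # report with a traceback and keep going instead of crashing the whole webui
                errors.report(f"Error creating UI for {script.filename}: ", exc_info=True)
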
From 9621ca4d64bbe59880d869b923e1572f1475a52b Mon Sep 17 00:00:00 2001
From: Charlie Joynt
Date: Mon, 27 Nov 2023 11:39:50 +0000
Subject: [PATCH 0312/1174] Allow use of multiple styles csv files

---
 modules/styles.py | 203 ++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 171 insertions(+), 32 deletions(-)

diff --git a/modules/styles.py b/modules/styles.py
index 0740fe1b1c0..974d3289b26 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,4 +1,5 @@
 import csv
+import fnmatch
 import os
 import os.path
 import re
@@ -10,6 +11,23 @@ class PromptStyle(typing.NamedTuple):
     name: str
     prompt: str
     negative_prompt: str
+    path: str = None
+
+
+def clean_text(text: str) -> str:
+    """
+    Iterating through a list of regular expressions and replacement strings, we
+    clean up the prompt and style text to make it easier to match against each
+    other.
+    """
+    re_list = [
+        ("multiple commas", re.compile("(,+\s+)+,?"), ", "),
+        ("multiple spaces", re.compile("\s{2,}"), " "),
+    ]
+    for _, regex, replace in re_list:
+        text = regex.sub(replace, text)
+
+    return text.strip(", ")


 def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -26,41 +44,64 @@ def apply_styles_to_prompt(prompt, styles):
     for style in styles:
         prompt = merge_prompts(style, prompt)

-    return prompt
+    return clean_text(prompt)


-re_spaces = re.compile(" +")
+def unwrap_style_text_from_prompt(style_text, prompt):
+    """
+    Checks the prompt to see if the style text is wrapped around it. If so,
+    returns True plus the prompt text without the style text. Otherwise, returns
+    False with the original prompt.

-
-def extract_style_text_from_prompt(style_text, prompt):
-    stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
-    stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
+    Note that the "cleaned" version of the style text is only used for matching
+    purposes here. It isn't returned; the original style text is not modified.
+    """
+    stripped_prompt = clean_text(prompt)
+    stripped_style_text = clean_text(style_text)

     if "{prompt}" in stripped_style_text:
-        left, right = stripped_style_text.split("{prompt}", 2)
+        # Work out whether the prompt is wrapped in the style text. If so, we
+        # return True and the "inner" prompt text that isn't part of the style.
+        try:
+            left, right = stripped_style_text.split("{prompt}", 2)
+        except ValueError as e:
+            # If the style text has multiple "{prompt}"s, we can't split it into
+            # two parts. This is an error, but we can't do anything about it.
+            print("Unable to compare style text to prompt:`n{style_text}")
+            print(f"Error: {e}")
+            return False, prompt
         if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
-            prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
+            prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
             return True, prompt
     else:
+        # Work out whether the given prompt ends with the style text. If so, we
+        # return True and the prompt text up to where the style text starts.
        if stripped_prompt.endswith(stripped_style_text):
-            prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
-
-            if prompt.endswith(', '):
+            prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
+            if prompt.endswith(", "):
                 prompt = prompt[:-2]
-
             return True, prompt

     return False, prompt


-def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
+def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
+    """
+    Takes a style and compares it to the prompt and negative prompt. 
If the style + matches, returns True plus the prompt and negative prompt with the style text + removed. Otherwise, returns False with the original prompt and negative prompt. + """ if not style.prompt and not style.negative_prompt: return False, prompt, negative_prompt - match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt) + match_positive, extracted_positive = unwrap_style_text_from_prompt( + style.prompt, prompt + ) if not match_positive: return False, prompt, negative_prompt - match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt) + match_negative, extracted_negative = unwrap_style_text_from_prompt( + style.negative_prompt, negative_prompt + ) if not match_negative: return False, prompt, negative_prompt @@ -69,25 +110,88 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt): class StyleDatabase: def __init__(self, path: str): - self.no_style = PromptStyle("None", "", "") + self.no_style = PromptStyle("None", "", "", None) self.styles = {} self.path = path + folder, file = os.path.split(self.path) + self.default_file = file.split("*")[0] + ".csv" + if self.default_file == ".csv": + self.default_file = "styles.csv" + self.default_path = os.path.join(folder, self.default_file) + + self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] + self.reload() def reload(self): + """ + Clears the style database and reloads the styles from the CSV file(s) + matching the path used to initialize the database. + """ self.styles.clear() - if not os.path.exists(self.path): + path, filename = os.path.split(self.path) + + if "*" in filename: + fileglob = filename.split("*")[0] + "*.csv" + filelist = [] + for file in os.listdir(path): + if fnmatch.fnmatch(file, fileglob): + filelist.append(file) + # Add a visible divider to the style list + half_len = round(len(file) / 2) + divider = f"{'-' * (20 - half_len)} {file.upper()}" + divider = f"{divider} {'-' * (40 - len(divider))}" + self.styles[divider] = PromptStyle( + f"{divider}", None, None, "do_not_save" + ) + # Add styles from this CSV file + self.load_from_csv(os.path.join(path, file)) + if len(filelist) == 0: + print(f"No styles found in {path} matching {fileglob}") + return + elif not os.path.exists(self.path): + print(f"Style database not found: {self.path}") return + else: + self.load_from_csv(self.path) - with open(self.path, "r", encoding="utf-8-sig", newline='') as file: + def load_from_csv(self, path: str): + with open(path, "r", encoding="utf-8-sig", newline="") as file: reader = csv.DictReader(file, skipinitialspace=True) for row in reader: + # Ignore empty rows or rows starting with a comment + if not row or row["name"].startswith("#"): + continue # Support loading old CSV format with "name, text"-columns prompt = row["prompt"] if "prompt" in row else row["text"] negative_prompt = row.get("negative_prompt", "") - self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) + # Add style to database + self.styles[row["name"]] = PromptStyle( + row["name"], prompt, negative_prompt, path + ) + + def get_style_paths(self) -> list(): + """ + Returns a list of all distinct paths, including the default path, of + files that styles are loaded from.""" + # Update any styles without a path to the default path + for style in list(self.styles.values()): + if not style.path: + self.styles[style.name] = style._replace(path=self.default_path) + + # Create a list of all distinct paths, including the default 
path
+        style_paths = set()
+        style_paths.add(self.default_path)
+        for _, style in self.styles.items():
+            if style.path:
+                style_paths.add(style.path)
+
+        # Remove any paths for styles that are just list dividers
+        style_paths.remove("do_not_save")
+
+        return list(style_paths)

     def get_style_prompts(self, styles):
         return [self.styles.get(x, self.no_style).prompt for x in styles]
@@ -96,20 +200,53 @@ def get_negative_style_prompts(self, styles):
         return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

     def apply_styles_to_prompt(self, prompt, styles):
-        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
+        return apply_styles_to_prompt(
+            prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
+        )

     def apply_negative_styles_to_prompt(self, prompt, styles):
-        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
-
-    def save_styles(self, path: str) -> None:
-        # Always keep a backup file around
-        if os.path.exists(path):
-            shutil.copy(path, f"{path}.bak")
-
-        with open(path, "w", encoding="utf-8-sig", newline='') as file:
-            writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
-            writer.writeheader()
-            writer.writerows(style._asdict() for k, style in self.styles.items())
+        return apply_styles_to_prompt(
+            prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
+        )
+
+    def save_styles(self, path: str = None) -> None:
+        # The path argument is deprecated, but kept for backwards compatibility
+        _ = path
+
+        # Update any styles without a path to the default path
+        for style in list(self.styles.values()):
+            if not style.path:
+                self.styles[style.name] = style._replace(path=self.default_path)
+
+        # Create a list of all distinct paths, including the default path
+        style_paths = set()
+        style_paths.add(self.default_path)
+        for _, style in self.styles.items():
+            if style.path:
+                style_paths.add(style.path)
+
+        # Remove any paths for styles that are just list dividers
+        style_paths.remove("do_not_save")
+
+        csv_names = [os.path.split(path)[1].lower() for path in style_paths]
+
+        for style_path in style_paths:
+            # Always keep a backup file around
+            if os.path.exists(style_path):
+                shutil.copy(style_path, f"{style_path}.bak")
+
+            # Write the styles to the CSV file
+            with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
+                writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
+                writer.writeheader()
+                for style in (s for s in self.styles.values() if s.path == style_path):
+                    # Skip style list dividers, e.g. "STYLES.CSV"
+                    if style.name.lower().strip("# ") in csv_names:
+                        continue
+                    # Write style fields, ignoring the path field
+                    writer.writerow(
+                        {k: v for k, v in style._asdict().items() if k != "path"}
+                    )

     def extract_styles_from_prompt(self, prompt, negative_prompt):
         extracted = []
@@ -120,7 +257,9 @@ def extract_styles_from_prompt(self, prompt, negative_prompt):
             found_style = None

             for style in applicable_styles:
-                is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
+                is_match, new_prompt, new_neg_prompt = extract_original_prompts(
+                    style, prompt, negative_prompt
+                )

                 if is_match:
                     found_style = style
                     prompt = new_prompt

From 1c64bb71402c2cd62ac98f936203437f0c4fcd02 Mon Sep 17 00:00:00 2001
From: MisterSeajay
Date: Mon, 27 Nov 2023 11:57:27 +0000
Subject: [PATCH 0313/1174] bugfix for warning message (#6)

---
 modules/styles.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/styles.py b/modules/styles.py
index 974d3289b26..e73920c70dd 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -66,7 +66,7 @@ def unwrap_style_text_from_prompt(style_text, prompt):
         except ValueError as e:
             # If the style text has multiple "{prompt}"s, we can't split it into
             # two parts. This is an error, but we can't do anything about it.
-            print("Unable to compare style text to prompt:`n{style_text}")
+            print(f"Unable to compare style text to prompt:`n{style_text}")
             print(f"Error: {e}")
             return False, prompt
         if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):

From a75314b41f938d1e598916ecdd0f14126ae1876b Mon Sep 17 00:00:00 2001
From: MisterSeajay
Date: Mon, 27 Nov 2023 12:03:42 +0000
Subject: [PATCH 0314/1174] bugfix for warning message (#6)

* bugfix for warning message

* bugfix error message
---
 modules/styles.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/styles.py b/modules/styles.py
index e73920c70dd..4d218cd7ed0 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -66,7 +66,7 @@ def unwrap_style_text_from_prompt(style_text, prompt):
         except ValueError as e:
             # If the style text has multiple "{prompt}"s, we can't split it into
             # two parts. This is an error, but we can't do anything about it.
-            print(f"Unable to compare style text to prompt:`n{style_text}")
+            print(f"Unable to compare style text to prompt:\n{style_text}")
             print(f"Error: {e}")
             return False, prompt
         if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
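Before the squashed version of the styles patch below, a short usage sketch of the unwrapping behaviour the patches above implement (illustrative values; it assumes modules.styles from this series is importable):

    from modules.styles import unwrap_style_text_from_prompt

    # style wraps the prompt via the {prompt} placeholder
    print(unwrap_style_text_from_prompt("cinematic, {prompt}, 35mm", "cinematic, a red fox, 35mm"))
    # -> (True, 'a red fox')

    # style appended to the end of the prompt; the trailing ", " is stripped too
    print(unwrap_style_text_from_prompt("watercolor", "a red fox, watercolor"))
    # -> (True, 'a red fox')

From 26a0c29587da428d27fd3e6a95491776ef66bbdd Mon Sep 17 00:00:00 2001
From: Charlie Joynt
Date: Mon, 27 Nov 2023 11:39:50 +0000
Subject: [PATCH 0315/1174] Allow use of multiple styles csv files

* https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/14122
Fix edge case where style text has multiple {prompt} placeholders
* https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/14005
---
 modules/styles.py | 203 ++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 171 insertions(+), 32 deletions(-)

diff --git a/modules/styles.py b/modules/styles.py
index 0740fe1b1c0..4d218cd7ed0 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,4 +1,5 @@
 import csv
+import fnmatch
 import os
 import os.path
 import re
@@ -10,6 +11,23 @@ class PromptStyle(typing.NamedTuple):
     name: str
     prompt: str
     negative_prompt: str
+    path: str = None
+
+
+def clean_text(text: str) -> str:
+    """
+    Iterating through a list of regular expressions and replacement strings, we
+    clean up the prompt and style text to make it easier to match against each
+    other.
+    """
+    re_list = [
+        ("multiple commas", re.compile("(,+\s+)+,?"), ", "),
+        ("multiple spaces", re.compile("\s{2,}"), " "),
+    ]
+    for _, regex, replace in re_list:
+        text = regex.sub(replace, text)
+
+    return text.strip(", ")


 def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -26,41 +44,64 @@ def apply_styles_to_prompt(prompt, styles):
     for style in styles:
         prompt = merge_prompts(style, prompt)

-    return prompt
+    return clean_text(prompt)


-re_spaces = re.compile(" +")
+def unwrap_style_text_from_prompt(style_text, prompt):
+    """
+    Checks the prompt to see if the style text is wrapped around it. If so,
+    returns True plus the prompt text without the style text. Otherwise, returns
+    False with the original prompt.

-
-def extract_style_text_from_prompt(style_text, prompt):
-    stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
-    stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
+    Note that the "cleaned" version of the style text is only used for matching
+    purposes here. It isn't returned; the original style text is not modified.
+    """
+    stripped_prompt = clean_text(prompt)
+    stripped_style_text = clean_text(style_text)

     if "{prompt}" in stripped_style_text:
-        left, right = stripped_style_text.split("{prompt}", 2)
+        # Work out whether the prompt is wrapped in the style text. If so, we
+        # return True and the "inner" prompt text that isn't part of the style.
+        try:
+            left, right = stripped_style_text.split("{prompt}", 2)
+        except ValueError as e:
+            # If the style text has multiple "{prompt}"s, we can't split it into
+            # two parts. This is an error, but we can't do anything about it.
+            print(f"Unable to compare style text to prompt:\n{style_text}")
+            print(f"Error: {e}")
+            return False, prompt
         if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
-            prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
+            prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
             return True, prompt
     else:
+        # Work out whether the given prompt ends with the style text. If so, we
+        # return True and the prompt text up to where the style text starts.
        if stripped_prompt.endswith(stripped_style_text):
-            prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
-
-            if prompt.endswith(', '):
+            prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
+            if prompt.endswith(", "):
                 prompt = prompt[:-2]
-
             return True, prompt

     return False, prompt


-def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
+def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
+    """
+    Takes a style and compares it to the prompt and negative prompt. If the style
+    matches, returns True plus the prompt and negative prompt with the style text
+    removed. Otherwise, returns False with the original prompt and negative prompt. 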
+ """ if not style.prompt and not style.negative_prompt: return False, prompt, negative_prompt - match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt) + match_positive, extracted_positive = unwrap_style_text_from_prompt( + style.prompt, prompt + ) if not match_positive: return False, prompt, negative_prompt - match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt) + match_negative, extracted_negative = unwrap_style_text_from_prompt( + style.negative_prompt, negative_prompt + ) if not match_negative: return False, prompt, negative_prompt @@ -69,25 +110,88 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt): class StyleDatabase: def __init__(self, path: str): - self.no_style = PromptStyle("None", "", "") + self.no_style = PromptStyle("None", "", "", None) self.styles = {} self.path = path + folder, file = os.path.split(self.path) + self.default_file = file.split("*")[0] + ".csv" + if self.default_file == ".csv": + self.default_file = "styles.csv" + self.default_path = os.path.join(folder, self.default_file) + + self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] + self.reload() def reload(self): + """ + Clears the style database and reloads the styles from the CSV file(s) + matching the path used to initialize the database. + """ self.styles.clear() - if not os.path.exists(self.path): + path, filename = os.path.split(self.path) + + if "*" in filename: + fileglob = filename.split("*")[0] + "*.csv" + filelist = [] + for file in os.listdir(path): + if fnmatch.fnmatch(file, fileglob): + filelist.append(file) + # Add a visible divider to the style list + half_len = round(len(file) / 2) + divider = f"{'-' * (20 - half_len)} {file.upper()}" + divider = f"{divider} {'-' * (40 - len(divider))}" + self.styles[divider] = PromptStyle( + f"{divider}", None, None, "do_not_save" + ) + # Add styles from this CSV file + self.load_from_csv(os.path.join(path, file)) + if len(filelist) == 0: + print(f"No styles found in {path} matching {fileglob}") + return + elif not os.path.exists(self.path): + print(f"Style database not found: {self.path}") return + else: + self.load_from_csv(self.path) - with open(self.path, "r", encoding="utf-8-sig", newline='') as file: + def load_from_csv(self, path: str): + with open(path, "r", encoding="utf-8-sig", newline="") as file: reader = csv.DictReader(file, skipinitialspace=True) for row in reader: + # Ignore empty rows or rows starting with a comment + if not row or row["name"].startswith("#"): + continue # Support loading old CSV format with "name, text"-columns prompt = row["prompt"] if "prompt" in row else row["text"] negative_prompt = row.get("negative_prompt", "") - self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) + # Add style to database + self.styles[row["name"]] = PromptStyle( + row["name"], prompt, negative_prompt, path + ) + + def get_style_paths(self) -> list(): + """ + Returns a list of all distinct paths, including the default path, of + files that styles are loaded from.""" + # Update any styles without a path to the default path + for style in list(self.styles.values()): + if not style.path: + self.styles[style.name] = style._replace(path=self.default_path) + + # Create a list of all distinct paths, including the default path + style_paths = set() + style_paths.add(self.default_path) + for _, style in self.styles.items(): + if style.path: + style_paths.add(style.path) + + # Remove any paths 
for styles that are just list dividers + style_paths.remove("do_not_save") + + return list(style_paths) def get_style_prompts(self, styles): return [self.styles.get(x, self.no_style).prompt for x in styles] @@ -96,20 +200,53 @@ def get_negative_style_prompts(self, styles): return [self.styles.get(x, self.no_style).negative_prompt for x in styles] def apply_styles_to_prompt(self, prompt, styles): - return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) + return apply_styles_to_prompt( + prompt, [self.styles.get(x, self.no_style).prompt for x in styles] + ) def apply_negative_styles_to_prompt(self, prompt, styles): - return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) - - def save_styles(self, path: str) -> None: - # Always keep a backup file around - if os.path.exists(path): - shutil.copy(path, f"{path}.bak") - - with open(path, "w", encoding="utf-8-sig", newline='') as file: - writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) - writer.writeheader() - writer.writerows(style._asdict() for k, style in self.styles.items()) + return apply_styles_to_prompt( + prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles] + ) + + def save_styles(self, path: str = None) -> None: + # The path argument is deprecated, but kept for backwards compatibility + _ = path + + # Update any styles without a path to the default path + for style in list(self.styles.values()): + if not style.path: + self.styles[style.name] = style._replace(path=self.default_path) + + # Create a list of all distinct paths, including the default path + style_paths = set() + style_paths.add(self.default_path) + for _, style in self.styles.items(): + if style.path: + style_paths.add(style.path) + + # Remove any paths for styles that are just list dividers + style_paths.remove("do_not_save") + + csv_names = [os.path.split(path)[1].lower() for path in style_paths] + + for style_path in style_paths: + # Always keep a backup file around + if os.path.exists(style_path): + shutil.copy(style_path, f"{style_path}.bak") + + # Write the styles to the CSV file + with open(style_path, "w", encoding="utf-8-sig", newline="") as file: + writer = csv.DictWriter(file, fieldnames=self.prompt_fields) + writer.writeheader() + for style in (s for s in self.styles.values() if s.path == style_path): + # Skip style list dividers, e.g. 
"STYLES.CSV" + if style.name.lower().strip("# ") in csv_names: + continue + # Write style fields, ignoring the path field + writer.writerow( + {k: v for k, v in style._asdict().items() if k != "path"} + ) def extract_styles_from_prompt(self, prompt, negative_prompt): extracted = [] @@ -120,7 +257,9 @@ def extract_styles_from_prompt(self, prompt, negative_prompt): found_style = None for style in applicable_styles: - is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt) + is_match, new_prompt, new_neg_prompt = extract_original_prompts( + style, prompt, negative_prompt + ) if is_match: found_style = style prompt = new_prompt From 23c36f59b4a423362d74f1ca2cc69871ae101e0e Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Mon, 27 Nov 2023 21:10:26 +0900 Subject: [PATCH 0316/1174] Support XYZ scripts / split hires path from unet --- .../hypertile/scripts/hypertile_script.py | 11 ++-- .../hypertile/scripts/hypertile_xyz.py | 52 +++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 extensions-builtin/hypertile/scripts/hypertile_xyz.py diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py index 3cc29cd1fa4..b2413cc5fc3 100644 --- a/extensions-builtin/hypertile/scripts/hypertile_script.py +++ b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -1,5 +1,6 @@ import hypertile from modules import scripts, script_callbacks, shared +from scripts.hypertile_xyz import add_axis_options class ScriptHypertile(scripts.Script): @@ -17,7 +18,10 @@ def process(self, p, *args): configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) def before_hr(self, p, *args): - configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet) + # exclusive hypertile seed for the second pass + if not shared.opts.hypertile_enable_unet: + hypertile.set_hypertile_seed(p.all_seeds[0]) + configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass) def configure_hypertile(width, height, enable_unet=True): @@ -57,12 +61,12 @@ def on_ui_settings(): "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"), "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}), + "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"), "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}), + "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", 
gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), } for name, opt in options.items(): @@ -71,3 +75,4 @@ def on_ui_settings(): script_callbacks.on_ui_settings(on_ui_settings) +script_callbacks.on_before_ui(add_axis_options) \ No newline at end of file diff --git a/extensions-builtin/hypertile/scripts/hypertile_xyz.py b/extensions-builtin/hypertile/scripts/hypertile_xyz.py new file mode 100644 index 00000000000..eaf7c8d7a6d --- /dev/null +++ b/extensions-builtin/hypertile/scripts/hypertile_xyz.py @@ -0,0 +1,52 @@ +from modules import scripts +xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module +from modules.shared import opts + +def int_applier(value_name:str, min_range:int = -1, max_range:int = -1): + """ + Returns a function that applies the given value to the given value_name in opts.data. + """ + # convert to int + def validate(value_name:str, value:str): + try: + value = int(value) + except: + raise ValueError(f"Value {value} for {value_name} is not an integer") + # validate value + if not min_range == -1: + assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}" + if not max_range == -1: + assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}" + def apply_int(p, x, xs): + validate(value_name, x) + opts.data[value_name] = int(x) + return apply_int + +def bool_applier(value_name:str): + """ + Returns a function that applies the given value to the given value_name in opts.data. + """ + def validate(value_name:str, value:str): + assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false" + def apply_bool(p, x, xs): + validate(value_name, x) + value_boolean = x.lower() == "true" + opts.data[value_name] = value_boolean + return apply_bool + +def add_axis_options(): + extra_axis_options = [ + xyz_grid.AxisOption("[Hypertile] Unet First pass Enabled", str, bool_applier("hypertile_enable_unet"), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[Hypertile] Unet Second pass Enabled", str, bool_applier("hypertile_enable_unet_secondpass"), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[Hypertile] Unet Max Depth", int, int_applier("hypertile_max_depth_unet", 0, 3), choices=lambda: [str(x) for x in range(4)]), + xyz_grid.AxisOption("[Hypertile] Unet Max Tile Size", int, int_applier("hypertile_max_tile_unet", 0, 512)), + xyz_grid.AxisOption("[Hypertile] Unet Swap Size", int, int_applier("hypertile_swap_size_unet", 0, 64)), + xyz_grid.AxisOption("[Hypertile] VAE Enabled", str, bool_applier("hypertile_enable_vae"), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[Hypertile] VAE Max Depth", int, int_applier("hypertile_max_depth_vae", 0, 3), choices=lambda: [str(x) for x in range(4)]), + xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)), + xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)), + ] + # check if the axis options have already been added + if any(set(opt.label for opt in extra_axis_options).intersection(set(opt.label for opt in xyz_grid.axis_options))): + return + xyz_grid.axis_options.extend(extra_axis_options) \ No newline at end of file From 601a7b4ce5b28efd29b1668c7b8b74fb6b62f6f3 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Mon, 27 Nov 2023 22:10:31 +0900 Subject: [PATCH 0317/1174] 
cache divisors / fix ruff --- extensions-builtin/hypertile/hypertile.py | 24 ++++++++++++------- .../hypertile/scripts/hypertile_script.py | 2 +- .../hypertile/scripts/hypertile_xyz.py | 18 +++++++------- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/extensions-builtin/hypertile/hypertile.py b/extensions-builtin/hypertile/hypertile.py index feb02fd2737..0f40e2d3925 100644 --- a/extensions-builtin/hypertile/hypertile.py +++ b/extensions-builtin/hypertile/hypertile.py @@ -6,7 +6,6 @@ from __future__ import annotations -import functools from dataclasses import dataclass from typing import Callable @@ -189,20 +188,27 @@ class HypertileParams: RNG_INSTANCE = random.Random() - -def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: +@cache +def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]: """ - Returns a random divisor of value that + Returns divisors of value that x * min_value <= value - if max_options is 1, the behavior is deterministic + in big -> small order, amount of divisors is limited by max_options """ + max_options = max(1, max_options) # at least 1 option should be returned min_value = min(min_value, value) - - # All big divisors of value (inclusive) divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order - ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order + return ns + +def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: + """ + Returns a random divisor of value that + x * min_value <= value + if max_options is 1, the behavior is deterministic + """ + ns = get_divisors(value, min_value, max_options=max_options) # get cached divisors idx = RNG_INSTANCE.randint(0, len(ns) - 1) return ns[idx] @@ -212,7 +218,7 @@ def set_hypertile_seed(seed: int) -> None: RNG_INSTANCE.seed(seed) -@functools.cache +@cache def largest_tile_size_available(width: int, height: int) -> int: """ Calculates the largest tile size available for a given width and height diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py index b2413cc5fc3..d3ab6091500 100644 --- a/extensions-builtin/hypertile/scripts/hypertile_script.py +++ b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -75,4 +75,4 @@ def on_ui_settings(): script_callbacks.on_ui_settings(on_ui_settings) -script_callbacks.on_before_ui(add_axis_options) \ No newline at end of file +script_callbacks.on_before_ui(add_axis_options) diff --git a/extensions-builtin/hypertile/scripts/hypertile_xyz.py b/extensions-builtin/hypertile/scripts/hypertile_xyz.py index eaf7c8d7a6d..3007a08326c 100644 --- a/extensions-builtin/hypertile/scripts/hypertile_xyz.py +++ b/extensions-builtin/hypertile/scripts/hypertile_xyz.py @@ -1,17 +1,17 @@ from modules import scripts -xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module from modules.shared import opts +xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module + def int_applier(value_name:str, min_range:int = -1, max_range:int = -1): """ Returns a function that applies the given value to the given value_name in opts.data. 
""" # convert to int def validate(value_name:str, value:str): - try: - value = int(value) - except: - raise ValueError(f"Value {value} for {value_name} is not an integer") + if not value.isnumeric(): + raise ValueError(f"Value {value} for {value_name} must be an integer") + value = int(value) # validate value if not min_range == -1: assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}" @@ -46,7 +46,9 @@ def add_axis_options(): xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)), xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)), ] - # check if the axis options have already been added - if any(set(opt.label for opt in extra_axis_options).intersection(set(opt.label for opt in xyz_grid.axis_options))): + set_a = set([opt.label for opt in xyz_grid.axis_options]) + set_b = set([opt.label for opt in extra_axis_options]) + if set_a.intersection(set_b): return - xyz_grid.axis_options.extend(extra_axis_options) \ No newline at end of file + + xyz_grid.axis_options.extend(extra_axis_options) From f207eb7a0d8b4443dbe665df99c31f8ff91660fd Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Mon, 27 Nov 2023 22:11:28 +0900 Subject: [PATCH 0318/1174] fix ruff in hypertile_xyz.py --- extensions-builtin/hypertile/scripts/hypertile_xyz.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/hypertile/scripts/hypertile_xyz.py b/extensions-builtin/hypertile/scripts/hypertile_xyz.py index 3007a08326c..4055a9eada7 100644 --- a/extensions-builtin/hypertile/scripts/hypertile_xyz.py +++ b/extensions-builtin/hypertile/scripts/hypertile_xyz.py @@ -46,8 +46,8 @@ def add_axis_options(): xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)), xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)), ] - set_a = set([opt.label for opt in xyz_grid.axis_options]) - set_b = set([opt.label for opt in extra_axis_options]) + set_a = set(opt.label for opt in xyz_grid.axis_options) + set_b = set(opt.label for opt in extra_axis_options) if set_a.intersection(set_b): return From 524d6a4dbae68bf557d9c5fe686707d96841e0b5 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Mon, 27 Nov 2023 22:13:18 +0900 Subject: [PATCH 0319/1174] fix ruff - set comprehension --- extensions-builtin/hypertile/scripts/hypertile_xyz.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/hypertile/scripts/hypertile_xyz.py b/extensions-builtin/hypertile/scripts/hypertile_xyz.py index 4055a9eada7..928e99652cb 100644 --- a/extensions-builtin/hypertile/scripts/hypertile_xyz.py +++ b/extensions-builtin/hypertile/scripts/hypertile_xyz.py @@ -46,8 +46,8 @@ def add_axis_options(): xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)), xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)), ] - set_a = set(opt.label for opt in xyz_grid.axis_options) - set_b = set(opt.label for opt in extra_axis_options) + set_a = {opt.label for opt in xyz_grid.axis_options} + set_b = {opt.label for opt in extra_axis_options} if set_a.intersection(set_b): return From ec78354efa179b64e92d6b98d781f6572b4eb084 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: 
Mon, 27 Nov 2023 22:25:28 +0900
Subject: [PATCH 0320/1174] hypertile_xyz: we don't need isnumeric check for
 AxisOption

---
 extensions-builtin/hypertile/scripts/hypertile_xyz.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/extensions-builtin/hypertile/scripts/hypertile_xyz.py b/extensions-builtin/hypertile/scripts/hypertile_xyz.py
index 928e99652cb..9e96ae3c527 100644
--- a/extensions-builtin/hypertile/scripts/hypertile_xyz.py
+++ b/extensions-builtin/hypertile/scripts/hypertile_xyz.py
@@ -7,10 +7,7 @@ def int_applier(value_name:str, min_range:int = -1, max_range:int = -1):
     """
     Returns a function that applies the given value to the given value_name in opts.data.
     """
-    # convert to int
     def validate(value_name:str, value:str):
         value = int(value)
         # validate value
         if not min_range == -1:

From 3cd6e1d0a0877e6f1ac931c8253e6eee09da3805 Mon Sep 17 00:00:00 2001
From: obsol <33932119+read-0nly@users.noreply.github.com>
Date: Mon, 27 Nov 2023 19:21:43 -0500
Subject: [PATCH 0321/1174] Update devices.py

fixes issue where "--use-cpu all" properly makes SD run on CPU but leaves
ControlNet (and other extensions, I presume) pointed at the GPU, causing a
crash in ControlNet from a device mismatch between SD and CN

https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/14097
---
 modules/devices.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/devices.py b/modules/devices.py
index c01f06024b4..65efcf1ebdd 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -38,7 +38,7 @@ def get_optimal_device():


 def get_device_for(task):
-    if task in shared.cmd_opts.use_cpu:
+    if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
         return cpu

     return get_optimal_device()
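To make the effect of the devices change above concrete, a small sketch (it assumes webui's modules.shared and modules.devices are importable; "controlnet" is only an illustrative task name):

    from modules import devices, shared

    # "--use-cpu all" puts the single entry "all" into cmd_opts.use_cpu
    shared.cmd_opts.use_cpu = ["all"]

    # previously only explicitly listed tasks matched; with the patch, "all"
    # now also covers tasks requested by extensions
    print(devices.get_device_for("controlnet"))  # cpu

From 03ee297aa22296ea12b965fc1cb11aa46375d372 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Mon, 27 Nov 2023 17:26:16 +0900
Subject: [PATCH 0322/1174] fix Auto focal point crop for opencv >= 4.8.x

autocrop.download_and_cache_models: in opencv >= 4.8 the face detection model
was updated; download the model that matches the installed opencv version, and
return the model path or raise an exception
---
 modules/textual_inversion/autocrop.py   | 29 ++++++++++++++-----------
 modules/textual_inversion/preprocess.py |  4 ++--
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 1675e39a54b..051be1188ea 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -3,6 +3,8 @@
 import os
 import numpy as np
 from PIL import ImageDraw
+from modules import paths_internal
+from pkg_resources import parse_version

 GREEN = "#0F0"
 BLUE = "#00F"
@@ -294,22 +296,23 @@ def is_square(w, h):
     return w == h


-def download_and_cache_models(dirname):
-    download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
-    model_file_name = 'face_detection_yunet.onnx'
+model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv')
+if parse_version(cv2.__version__) >= parse_version('4.8'):
+    model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx')
+    model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'
+else:
+    model_file_path = 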
os.path.join(model_dir_opencv, 'face_detection_yunet.onnx') + model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' - os.makedirs(dirname, exist_ok=True) - cache_file = os.path.join(dirname, model_file_name) - if not os.path.exists(cache_file): - print(f"downloading face detection model from '{download_url}' to '{cache_file}'") - response = requests.get(download_url) - with open(cache_file, "wb") as f: +def download_and_cache_models(): + if not os.path.exists(model_file_path): + os.makedirs(model_dir_opencv, exist_ok=True) + print(f"downloading face detection model from '{model_url}' to '{model_file_path}'") + response = requests.get(model_url) + with open(model_file_path, "wb") as f: f.write(response.content) - - if os.path.exists(cache_file): - return cache_file - return None + return model_file_path class PointOfInterest: diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index dbd856bd859..789fa08384b 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -3,7 +3,7 @@ import math import tqdm -from modules import paths, shared, images, deepbooru +from modules import shared, images, deepbooru from modules.textual_inversion import autocrop @@ -196,7 +196,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre dnn_model_path = None try: - dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) + dnn_model_path = autocrop.download_and_cache_models() except Exception as e: print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) From d608926f817b279d16b39a7875beec80d010a988 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 28 Nov 2023 12:12:27 +0900 Subject: [PATCH 0323/1174] reformat file with uniform indentation --- modules/textual_inversion/autocrop.py | 210 +++++++++++++------------- 1 file changed, 106 insertions(+), 104 deletions(-) diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 051be1188ea..e223a2e0cc9 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -27,7 +27,6 @@ def crop_image(im, settings): elif is_portrait(settings.crop_width, settings.crop_height): scale_by = settings.crop_height / im.height - im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) im_debug = im.copy() @@ -71,6 +70,7 @@ def crop_image(im, settings): return results + def focal_point(im, settings): corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] @@ -80,118 +80,120 @@ def focal_point(im, settings): weight_pref_total = 0 if corner_points: - weight_pref_total += settings.corner_points_weight + weight_pref_total += settings.corner_points_weight if entropy_points: - weight_pref_total += settings.entropy_points_weight + weight_pref_total += settings.entropy_points_weight if face_points: - weight_pref_total += settings.face_points_weight + weight_pref_total += settings.face_points_weight corner_centroid = None if corner_points: - corner_centroid = centroid(corner_points) - corner_centroid.weight = settings.corner_points_weight / weight_pref_total - pois.append(corner_centroid) + corner_centroid = 
centroid(corner_points) + corner_centroid.weight = settings.corner_points_weight / weight_pref_total + pois.append(corner_centroid) entropy_centroid = None if entropy_points: - entropy_centroid = centroid(entropy_points) - entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total - pois.append(entropy_centroid) + entropy_centroid = centroid(entropy_points) + entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total + pois.append(entropy_centroid) face_centroid = None if face_points: - face_centroid = centroid(face_points) - face_centroid.weight = settings.face_points_weight / weight_pref_total - pois.append(face_centroid) + face_centroid = centroid(face_points) + face_centroid.weight = settings.face_points_weight / weight_pref_total + pois.append(face_centroid) average_point = poi_average(pois, settings) if settings.annotate_image: - d = ImageDraw.Draw(im) - max_size = min(im.width, im.height) * 0.07 - if corner_centroid is not None: - color = BLUE - box = corner_centroid.bounding(max_size * corner_centroid.weight) - d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(corner_points) > 1: - for f in corner_points: - d.rectangle(f.bounding(4), outline=color) - if entropy_centroid is not None: - color = "#ff0" - box = entropy_centroid.bounding(max_size * entropy_centroid.weight) - d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(entropy_points) > 1: - for f in entropy_points: - d.rectangle(f.bounding(4), outline=color) - if face_centroid is not None: - color = RED - box = face_centroid.bounding(max_size * face_centroid.weight) - d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color) - d.ellipse(box, outline=color) - if len(face_points) > 1: - for f in face_points: - d.rectangle(f.bounding(4), outline=color) - - d.ellipse(average_point.bounding(max_size), outline=GREEN) + d = ImageDraw.Draw(im) + max_size = min(im.width, im.height) * 0.07 + if corner_centroid is not None: + color = BLUE + box = corner_centroid.bounding(max_size * corner_centroid.weight) + d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(corner_points) > 1: + for f in corner_points: + d.rectangle(f.bounding(4), outline=color) + if entropy_centroid is not None: + color = "#ff0" + box = entropy_centroid.bounding(max_size * entropy_centroid.weight) + d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(entropy_points) > 1: + for f in entropy_points: + d.rectangle(f.bounding(4), outline=color) + if face_centroid is not None: + color = RED + box = face_centroid.bounding(max_size * face_centroid.weight) + d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color) + d.ellipse(box, outline=color) + if len(face_points) > 1: + for f in face_points: + d.rectangle(f.bounding(4), outline=color) + + d.ellipse(average_point.bounding(max_size), outline=GREEN) return average_point def image_face_points(im, settings): if settings.dnn_model_path is not None: - detector = cv2.FaceDetectorYN.create( - settings.dnn_model_path, - "", - (im.width, im.height), - 0.9, # score threshold - 0.3, # nms threshold - 5000 # keep top k before nms - ) - faces = detector.detect(np.array(im)) - results = [] - if faces[1] is not None: - for face in faces[1]: - x = face[0] - y = face[1] - w = 
face[2] - h = face[3] - results.append( - PointOfInterest( - int(x + (w * 0.5)), # face focus left/right is center - int(y + (h * 0.33)), # face focus up/down is close to the top of the head - size = w, - weight = 1/len(faces[1]) - ) - ) - return results + detector = cv2.FaceDetectorYN.create( + settings.dnn_model_path, + "", + (im.width, im.height), + 0.9, # score threshold + 0.3, # nms threshold + 5000 # keep top k before nms + ) + faces = detector.detect(np.array(im)) + results = [] + if faces[1] is not None: + for face in faces[1]: + x = face[0] + y = face[1] + w = face[2] + h = face[3] + results.append( + PointOfInterest( + int(x + (w * 0.5)), # face focus left/right is center + int(y + (h * 0.33)), # face focus up/down is close to the top of the head + size=w, + weight=1 / len(faces[1]) + ) + ) + return results else: - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - - tries = [ - [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] - ] - for t in tries: - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except Exception: - continue - - if faces: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01], + [f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05], + [f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05] + ] + for t in tries: + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), + flags=cv2.CASCADE_SCALE_IMAGE) + except Exception: + continue + + if faces: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]), + weight=1 / len(rects)) for r in rects] return [] @@ -200,7 +202,7 @@ def image_corner_points(im, settings): # naive attempt at preventing focal points from collecting at watermarks near the bottom gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999") np_im = np.array(grayscale) @@ -208,7 
+210,7 @@ def image_corner_points(im, settings): np_im, maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.06, + minDistance=min(grayscale.width, grayscale.height) * 0.06, useHarrisDetector=False, ) @@ -217,8 +219,8 @@ def image_corner_points(im, settings): focal_points = [] for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points))) + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points))) return focal_points @@ -227,13 +229,13 @@ def image_entropy_points(im, settings): landscape = im.height < im.width portrait = im.height > im.width if landscape: - move_idx = [0, 2] - move_max = im.size[0] + move_idx = [0, 2] + move_max = im.size[0] elif portrait: - move_idx = [1, 3] - move_max = im.size[1] + move_idx = [1, 3] + move_max = im.size[1] else: - return [] + return [] e_max = 0 crop_current = [0, 0, settings.crop_width, settings.crop_height] @@ -243,14 +245,14 @@ def image_entropy_points(im, settings): e = image_entropy(crop) if (e > e_max): - e_max = e - crop_best = list(crop_current) + e_max = e + crop_best = list(crop_current) crop_current[move_idx[0]] += 4 crop_current[move_idx[1]] += 4 - x_mid = int(crop_best[0] + settings.crop_width/2) - y_mid = int(crop_best[1] + settings.crop_height/2) + x_mid = int(crop_best[0] + settings.crop_width / 2) + y_mid = int(crop_best[1] + settings.crop_height / 2) return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] From 39eae9f009c8302eed77b0942e1e634f6125d53e Mon Sep 17 00:00:00 2001 From: hidenorly Date: Wed, 29 Nov 2023 04:07:48 +0900 Subject: [PATCH 0324/1174] Revert "Add FP32 fallback support on sd_vae_approx" This reverts commit 58c19545c83fa6925c9ce2216ee64964eb5129ce. Since the modification is expected to move to mac_specific.py (https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046#issuecomment-1826731532) --- modules/sd_vae_approx.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 8370493f932..3965e223e6f 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -21,13 +21,7 @@ def __init__(self): def forward(self, x): extra = 11 - try: - x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) - except RuntimeError as e: - if "not implemented for" in str(e) and "Half" in str(e): - x = nn.functional.interpolate(x.to(torch.float32), (x.shape[2] * 2, x.shape[3] * 2)).to(x.dtype) - else: - print(f"An unexpected RuntimeError occurred: {str(e)}") + x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) x = nn.functional.pad(x, (extra, extra, extra, extra)) for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]: From a0096c58977c01ddc6a2b83a8a7b64da6fd4a51e Mon Sep 17 00:00:00 2001 From: hidenorly Date: Wed, 29 Nov 2023 04:45:04 +0900 Subject: [PATCH 0325/1174] Add FP32 fallback support on torch.nn.functional.interpolate This tries to execute interpolate with FP32 if it fails. Background: on some environments, such as Mx-chip macOS devices, we get an error as follows: ``` "torch/nn/functional.py", line 3931, in interpolate return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half' ``` In this case, ```--no-half``` doesn't help. 
Therefore this commit adds the FP32 fallback execution to solve it. Note that ```upsample_nearest2d``` is called from ```torch.nn.functional.interpolate```, and the fallback for torch.nn.functional.interpolate is necessary in ```modules/sd_vae_approx.py```'s ```VAEApprox.forward``` and ```repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py```'s ```Upsample.forward``` --- modules/mac_specific.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 89256c5b060..3538e659ded 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,6 +1,8 @@ import logging import torch +from typing import Optional, List +from torch import Tensor import platform from modules.sd_hijack_utils import CondFunc from packaging import version @@ -51,6 +53,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): return cumsum_func(input, *args, **kwargs) +# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 +def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor: + try: + return orig_func(*args, **kwargs) + except RuntimeError as e: + if "not implemented for" in str(e) and "Half" in str(e): + input_tensor = args[0] + return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype) + else: + print(f"An unexpected RuntimeError occurred: {str(e)}") + if has_mps: if platform.mac_ver()[0].startswith("13.2."): # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) @@ -77,6 +90,9 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs): # MPS workaround for https://github.com/pytorch/pytorch/issues/96113 CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') + # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046 + CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None) + # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 if platform.processor() == 'i386': for funcName in ['torch.argmax', 'torch.Tensor.argmax']: From 81c00728b8ec0b6c0e70ea10c7687aad065a95cb Mon Sep 17 00:00:00 2001 From: hidenorly Date: Wed, 29 Nov 2023 04:59:35 +0900 Subject: [PATCH 0326/1174] Fix the Ruff error about unused import --- modules/mac_specific.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 3538e659ded..d96d86d792c 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,7 +1,6 @@ import logging import torch -from typing import Optional, List from torch import Tensor import platform from modules.sd_hijack_utils import CondFunc from packaging import version From dec791d35ddcd02ca33563d3d0355e05e45de8ad Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 15:05:01 -0700 Subject: [PATCH 0327/1174] Removed code which forces the inpainting mask to be 0 or 1. Now fractional values (e.g. 0.5) are accepted.
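For illustration, a minimal sketch of the behavioural change, not part of the patch (the thresholding expression is taken from the diff below; the example image is contrived):
```
from PIL import Image

def mask_old(image):
    # pre-patch: the alpha channel is hard-thresholded to 0 or 255
    return image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)

def mask_new(image):
    # post-patch: the 8-bit alpha channel is kept as-is, so fractional
    # opacity survives into the inpainting mask
    return image.split()[-1].convert("L")

rgba = Image.new("RGBA", (4, 4), (0, 0, 0, 128))  # a 50%-opaque mask
print(mask_old(rgba).getpixel((0, 0)))  # 0   -- half-opacity was discarded
print(mask_new(rgba).getpixel((0, 0)))  # 128 -- half-opacity is preserved
```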
--- modules/processing.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index e124e7f0dd2..317458f582f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -83,7 +83,7 @@ def apply_overlay(image, paste_loc, index, overlays): def create_binary_mask(image): if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255): - image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + image = image.split()[-1].convert("L") else: image = image.convert('L') return image @@ -319,9 +319,6 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N conditioning_mask = np.array(image_mask.convert("L")) conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) - - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) else: conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:]) @@ -1504,7 +1501,6 @@ def init(self, all_prompts, all_seeds, all_subseeds): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] - latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) From bbba133f054706c3668b7d03b0e6d0afc15705db Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 15:09:43 -0700 Subject: [PATCH 0328/1174] Removed conflicting step that replaces the softly inpainted latents with a naive blend with the original latents. --- modules/processing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 317458f582f..ae894f1a756 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1523,9 +1523,6 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - if self.mask is not None: - samples = samples * self.nmask + self.init_latent * self.mask - del x devices.torch_gc() From e715e46b6aa7f2e5e147cfa1fa2f49b1d926a074 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 16:10:22 -0700 Subject: [PATCH 0329/1174] Implements "scheduling" for blending of the original latents and a latent blending formula that preserves details in blend transition areas. 
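As a toy illustration of the scheduling described above (a sketch, not part of the patch: the parameter names mirror the fields added in the diff below, and the sigma values are arbitrary):
```
import torch

nmask = torch.tensor([0.25, 0.50, 0.75])  # soft mask: 0 = keep original, 1 = keep denoised
mask_blend_power, mask_blend_scale, mask_blend_offset = 1.0, 1.0, 0.0

for sigma in (14.6, 1.0, 0.1):  # from high noise to low noise
    t = nmask ** ((sigma ** mask_blend_power) * mask_blend_scale + mask_blend_offset)
    print(f"sigma={sigma}: t={t.tolist()}")

# At high sigma, t approaches 0 and the original latents dominate the blend;
# as sigma approaches 0, t approaches nmask ** 0 = 1 and the denoised
# latents take over, giving a per-step "schedule" for the mask.
```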
--- modules/sd_samplers_cfg_denoiser.py | 61 ++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index b8101d38dc3..c4d6fda650f 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -43,6 +43,9 @@ def __init__(self, sampler): self.model_wrap = None self.mask = None self.nmask = None + self.mask_blend_power = 1 + self.mask_blend_scale = 1 + self.mask_blend_offset = 0 self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -56,6 +59,9 @@ def __init__(self, sampler): self.sampler = sampler self.model_wrap = None self.p = None + + # NOTE: masking before denoising can cause the original latents to be oversmoothed + # as the original latents do not have noise self.mask_before_denoising = False @property @@ -89,6 +95,55 @@ def update_inner_model(self): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): + def latent_blend(a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + # Record the original latent vector magnitudes. + # We bring them to a power so that larger magnitudes are favored over smaller ones. + # 64-bit operations are used here to allow large exponents. + detail_preservation = 32 + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** detail_preservation + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** detail_preservation + + one_minus_t = 1 - t + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / detail_preservation) + + # Linearly interpolate the image vectors. + image_interp = a * one_minus_t + b * t + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype) + + return image_interp + + def get_modified_nmask(nmask, _sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. 
+ + NOTE: "mask" is not used + """ + return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale + self.mask_blend_offset) + if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -105,8 +160,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - x = self.init_latent * self.mask + self.nmask * x + x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -207,8 +263,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): else: denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised + denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) From a6e584645305c0a91a3d46f73546e191b249210f Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 16:13:42 -0700 Subject: [PATCH 0330/1174] Nerfs the aggressive post-processing step of overlaying the original image. --- modules/processing.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index ae894f1a756..12e08e87630 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1412,7 +1412,12 @@ def init(self, all_prompts, all_seeds, all_subseeds): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - self.mask_for_overlay = image_mask + np_mask = np.array(image_mask).astype(np.float32) + np_mask /= 255 + np_mask = 1-pow(1-np_mask, 100) + np_mask *= 255 + np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1423,8 +1428,11 @@ def init(self, all_prompts, all_seeds, all_subseeds): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - np_mask = np.array(image_mask) - np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + np_mask = np.array(image_mask).astype(np.float32) + np_mask /= 255 + np_mask = 1-pow(1-np_mask, 100) + np_mask *= 255 + np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) self.mask_for_overlay = Image.fromarray(np_mask) self.overlay_images = [] From debf836fcc8d9becc3da8b1a29e33f40b0d9ef3e Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 16:15:36 -0700 Subject: [PATCH 0331/1174] Added UI elements to control blending parameters. 
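Since the new controls are plain fields on StableDiffusionProcessingImg2Img, the img2img API picks them up as well; a hypothetical request body (field names and defaults are from the updated test in the diff below, everything else is illustrative):
```
payload = {
    "init_images": ["<base64-encoded image>"],
    "mask": "<base64-encoded mask>",
    "denoising_strength": 0.75,
    "mask_blur": 4,
    "mask_blend_power": 1,   # exponent applied to sigma in the blend schedule
    "mask_blend_scale": 1,   # multiplier on the scheduled blend exponent
    "mask_blend_offset": 0,  # constant added to the scheduled blend exponent
}
```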
--- modules/img2img.py | 48 +++++++++++++++++++++++++++++++- modules/processing.py | 3 ++ modules/sd_samplers_common.py | 3 ++ modules/ui.py | 9 ++++++ scripts/outpainting_mk_2.py | 10 +++++-- scripts/poor_mans_outpainting.py | 11 ++++++-- test/test_img2img.py | 3 ++ 7 files changed, 82 insertions(+), 5 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 1519e132b2b..240d0588422 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -116,7 +116,47 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal process_images(p) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): +def img2img(id_task: str, + mode: int, + prompt: str, + negative_prompt: str, + prompt_styles, + init_img, + sketch, + init_img_with_mask, + inpaint_color_sketch, + inpaint_color_sketch_orig, + init_img_inpaint, + init_mask_inpaint, + steps: int, + sampler_name: str, + mask_blur: int, + mask_alpha: float, + mask_blend_power: float, + mask_blend_scale: float, + mask_blend_offset: float, + inpainting_fill: int, + n_iter: int, + batch_size: int, + cfg_scale: float, + image_cfg_scale: float, + denoising_strength: float, + selected_scale_tab: int, + height: int, + width: int, + scale_by: float, + resize_mode: int, + inpaint_full_res: bool, + inpaint_full_res_padding: int, + inpainting_mask_invert: int, + img2img_batch_input_dir: str, + img2img_batch_output_dir: str, + img2img_batch_inpaint_mask_dir: str, + override_settings_texts, + img2img_batch_use_png_info: bool, + img2img_batch_png_info_props: list, + img2img_batch_png_info_dir: str, + request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -174,6 +214,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s init_images=[image], mask=mask, mask_blur=mask_blur, + mask_blend_power=mask_blend_power, + mask_blend_scale=mask_blend_scale, + mask_blend_offset=mask_blend_offset, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -194,6 +237,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if mask: p.extra_generation_params["Mask blur"] = mask_blur + p.extra_generation_params["Mask blend power"] = mask_blend_power + p.extra_generation_params["Mask blend scale"] = mask_blend_scale + p.extra_generation_params["Mask blend offset"] = mask_blend_offset with closing(p): if is_batch: diff --git a/modules/processing.py b/modules/processing.py index 12e08e87630..da4d6fda9ff 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1349,6 +1349,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None + 
mask_blend_power: float = 1 + mask_blend_scale: float = 1 + mask_blend_offset: float = 0 inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 58efcad2374..8904da2fb6b 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -277,6 +277,9 @@ def initialize(self, p) -> dict: self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None + self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None + self.model_wrap_cfg.mask_blend_offset = p.mask_blend_offset if hasattr(p, 'mask_blend_offset') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) diff --git a/modules/ui.py b/modules/ui.py index 579bab9800c..86c130869e5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -732,6 +732,9 @@ def copy_image(img): with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") + mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_scale") + mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id="img2img_mask_blend_offset") with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -781,6 +784,9 @@ def select_img2img_tab(tab): sampler_name, mask_blur, mask_alpha, + mask_blend_power, + mask_blend_scale, + mask_blend_offset, inpainting_fill, batch_count, batch_size, @@ -879,6 +885,9 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), + (mask_blend_power, "Mask blend power"), + (mask_blend_scale, "Mask blend scale"), + (mask_blend_offset, "Mask blend offset"), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index c98ab48098e..6aa97edfabe 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -133,13 +133,16 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) + mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + mask_blend_offset = gr.Slider(label='Mask blend scale', 
minimum=-4, maximum=4, step=0.1, value=1, elem_id=self.elem_id("mask_blend_offset")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, direction, noise_q, color_variation] + return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -167,6 +170,9 @@ def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): p.mask_blur_x = mask_blur_x*4 p.mask_blur_y = mask_blur_y*4 + p.mask_blend_power = mask_blend_power + p.mask_blend_scale = mask_blend_scale + p.mask_blend_offset = mask_blend_offset init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index ea0632b68c1..b10140f1420 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -22,16 +22,23 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) + mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id=self.elem_id("mask_blend_offset")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, inpainting_fill, direction] + return [pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 + p.mask_blend_power = mask_blend_power + p.mask_blend_scale = mask_blend_scale + p.mask_blend_offset = mask_blend_offset + p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 117d2d1eb45..6289e59e14b 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -24,6 +24,9 @@ def simple_img2img_request(img2img_basic_image_base64): "inpainting_mask_invert": False, "mask": None, "mask_blur": 
4, + "mask_blend_power": 1, + "mask_blend_scale": 1, + "mask_blend_offset": 0, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From c5c7fa06aae1ae9f8b6d29ae2da3874921d4729b Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 22:35:07 -0700 Subject: [PATCH 0332/1174] Added slider for detail preservation strength, removed largely needless offset parameter, changed labels in UI and for saving to/pasting data from PNG files. --- modules/img2img.py | 10 +++++----- modules/processing.py | 2 +- modules/sd_samplers_cfg_denoiser.py | 11 +++++------ modules/sd_samplers_common.py | 2 +- modules/ui.py | 14 +++++++------- scripts/outpainting_mk_2.py | 12 ++++++------ scripts/poor_mans_outpainting.py | 12 ++++++------ test/test_img2img.py | 2 +- 8 files changed, 32 insertions(+), 33 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 240d0588422..023808d6c00 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -134,7 +134,7 @@ def img2img(id_task: str, mask_alpha: float, mask_blend_power: float, mask_blend_scale: float, - mask_blend_offset: float, + inpaint_detail_preservation: float, inpainting_fill: int, n_iter: int, batch_size: int, @@ -216,7 +216,7 @@ def img2img(id_task: str, mask_blur=mask_blur, mask_blend_power=mask_blend_power, mask_blend_scale=mask_blend_scale, - mask_blend_offset=mask_blend_offset, + inpaint_detail_preservation=inpaint_detail_preservation, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -237,9 +237,9 @@ def img2img(id_task: str, if mask: p.extra_generation_params["Mask blur"] = mask_blur - p.extra_generation_params["Mask blend power"] = mask_blend_power - p.extra_generation_params["Mask blend scale"] = mask_blend_scale - p.extra_generation_params["Mask blend offset"] = mask_blend_offset + p.extra_generation_params["Mask blending bias"] = mask_blend_power + p.extra_generation_params["Mask blending preservation"] = mask_blend_scale + p.extra_generation_params["Mask blending detail boost"] = inpaint_detail_preservation with closing(p): if is_batch: diff --git a/modules/processing.py b/modules/processing.py index da4d6fda9ff..361e8b05d72 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1351,7 +1351,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur: int = None mask_blend_power: float = 1 mask_blend_scale: float = 1 - mask_blend_offset: float = 0 + inpaint_detail_preservation: float = 16 inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index c4d6fda650f..598cd487640 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -45,7 +45,7 @@ def __init__(self, sampler): self.nmask = None self.mask_blend_power = 1 self.mask_blend_scale = 1 - self.mask_blend_offset = 0 + self.inpaint_detail_preservation = 16 self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -105,14 +105,13 @@ def latent_blend(a, b, t): # Record the original latent vector magnitudes. # We bring them to a power so that larger magnitudes are favored over smaller ones. # 64-bit operations are used here to allow large exponents. 
- detail_preservation = 32 - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** detail_preservation - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** detail_preservation + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation one_minus_t = 1 - t # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / detail_preservation) + interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation) # Linearly interpolate the image vectors. image_interp = a * one_minus_t + b * t @@ -142,7 +141,7 @@ def get_modified_nmask(nmask, _sigma): NOTE: "mask" is not used """ - return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale + self.mask_blend_offset) + return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale) if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 8904da2fb6b..ecd8ab0a093 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -279,7 +279,7 @@ def initialize(self, p) -> dict: self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None - self.model_wrap_cfg.mask_blend_offset = p.mask_blend_offset if hasattr(p, 'mask_blend_offset') else None + self.model_wrap_cfg.inpaint_detail_preservation = p.inpaint_detail_preservation if hasattr(p, 'inpaint_detail_preservation') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) diff --git a/modules/ui.py b/modules/ui.py index 86c130869e5..f5e201477a6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -732,9 +732,9 @@ def copy_image(img): with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") - mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_scale") - mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id="img2img_mask_blend_offset") + mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=1, elem_id="img2img_mask_blend_scale") + inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id="img2img_mask_blend_offset") with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -786,7 +786,7 @@ def select_img2img_tab(tab): mask_alpha, 
mask_blend_power, mask_blend_scale, - mask_blend_offset, + inpaint_detail_preservation, inpainting_fill, batch_count, batch_size, @@ -885,9 +885,9 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - (mask_blend_power, "Mask blend power"), - (mask_blend_scale, "Mask blend scale"), - (mask_blend_offset, "Mask blend offset"), + (mask_blend_power, "Mask blending bias"), + (mask_blend_scale, "Mask blending preservation"), + (inpaint_detail_preservation, "Mask blending detail boost"), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 6aa97edfabe..54d95825aba 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -133,16 +133,16 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - mask_blend_offset = gr.Slider(label='Mask blend scale', minimum=-4, maximum=4, step=0.1, value=1, elem_id=self.elem_id("mask_blend_offset")) + mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation] + return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -172,7 +172,7 @@ def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_ p.mask_blur_y = mask_blur_y*4 p.mask_blend_power = mask_blend_power p.mask_blend_scale = mask_blend_scale - p.mask_blend_offset = mask_blend_offset + p.inpaint_detail_preservation = inpaint_detail_preservation init_img = p.init_images[0] target_w = 
math.ceil((init_img.width + left + right) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index b10140f1420..e3acb3d4701 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -22,22 +22,22 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id=self.elem_id("mask_blend_offset")) + mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction] + return [pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 p.mask_blend_power = mask_blend_power p.mask_blend_scale = mask_blend_scale - p.mask_blend_offset = mask_blend_offset + p.inpaint_detail_preservation = inpaint_detail_preservation p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 6289e59e14b..88b06eb8d7f 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -26,7 +26,7 @@ def simple_img2img_request(img2img_basic_image_base64): "mask_blur": 4, "mask_blend_power": 1, "mask_blend_scale": 1, - "mask_blend_offset": 0, + "inpaint_detail_preservation": 16, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From 284fd8f415ec70e14ae5de0b7f5ce738007a6b7f Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 23:03:50 -0700 Subject: [PATCH 0333/1174] Tweaked UI sliders and labels. 
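One detail worth noting: infotext keys are matched by exact string, so the writer in img2img.py and the paste fields in ui.py must be renamed together. A simplified, self-contained sketch of that round trip (the real code wires gradio components, not strings):
```
extra_generation_params = {"Mask blending contrast boost": 4}  # written on save

paste_fields = [
    # (UI component, infotext key) -- the key must match the writer exactly
    ("inpaint_detail_preservation", "Mask blending contrast boost"),
]

for component, key in paste_fields:
    # a stale key (e.g. the old "Mask blending detail boost") would yield None
    print(component, "<-", extra_generation_params.get(key))
```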
--- modules/img2img.py | 2 +- modules/ui.py | 6 +++--- scripts/outpainting_mk_2.py | 4 ++-- scripts/poor_mans_outpainting.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 023808d6c00..0ae1636541a 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -239,7 +239,7 @@ def img2img(id_task: str, p.extra_generation_params["Mask blur"] = mask_blur p.extra_generation_params["Mask blending bias"] = mask_blend_power p.extra_generation_params["Mask blending preservation"] = mask_blend_scale - p.extra_generation_params["Mask blending detail boost"] = inpaint_detail_preservation + p.extra_generation_params["Mask blending contrast boost"] = inpaint_detail_preservation with closing(p): if is_batch: diff --git a/modules/ui.py b/modules/ui.py index f5e201477a6..3a9038b227d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -733,8 +733,8 @@ def copy_image(img): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=1, elem_id="img2img_mask_blend_scale") - inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id="img2img_mask_blend_offset") + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale") + inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset") with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -887,7 +887,7 @@ def select_img2img_tab(tab): (mask_blur, "Mask blur"), (mask_blend_power, "Mask blending bias"), (mask_blend_scale, "Mask blending preservation"), - (inpaint_detail_preservation, "Mask blending detail boost"), + (inpaint_detail_preservation, "Mask blending contrast boost"), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 54d95825aba..bd9cb61bf82 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -134,8 +134,8 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) + 
inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index e3acb3d4701..5388f5db462 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -23,8 +23,8 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) From c7a1ff87207544dd4bcf3aefffa67a4a38678c16 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 23:31:10 -0700 Subject: [PATCH 0334/1174] Tweaked default values. 
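The same defaults now live in several places (the processing dataclass, the CFG denoiser, the UI sliders from the previous commit, and the API test). A hypothetical consolidation, not in the patch, recording the values this commit settles on:
```
INPAINT_BLEND_DEFAULTS = {
    "mask_blend_power": 1.0,             # UI label: "Blending bias"
    "mask_blend_scale": 0.5,             # UI label: "Blending preservation"
    "inpaint_detail_preservation": 4.0,  # UI label: "Blending contrast boost"
}
print(INPAINT_BLEND_DEFAULTS)
```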
--- modules/processing.py | 4 ++-- modules/sd_samplers_cfg_denoiser.py | 4 ++-- test/test_img2img.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 361e8b05d72..92fdebadd98 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1350,8 +1350,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_y: int = 4 mask_blur: int = None mask_blend_power: float = 1 - mask_blend_scale: float = 1 - inpaint_detail_preservation: float = 16 + mask_blend_scale: float = 0.5 + inpaint_detail_preservation: float = 4 inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 598cd487640..ceb612d7936 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -44,8 +44,8 @@ def __init__(self, sampler): self.mask = None self.nmask = None self.mask_blend_power = 1 - self.mask_blend_scale = 1 - self.inpaint_detail_preservation = 16 + self.mask_blend_scale = 0.5 + self.inpaint_detail_preservation = 4 self.init_latent = None self.steps = None """number of steps as specified by user in UI""" diff --git a/test/test_img2img.py b/test/test_img2img.py index 88b06eb8d7f..5cda2dbaefe 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -25,8 +25,8 @@ def simple_img2img_request(img2img_basic_image_base64): "mask": None, "mask_blur": 4, "mask_blend_power": 1, - "mask_blend_scale": 1, - "inpaint_detail_preservation": 16, + "mask_blend_scale": 0.5, + "inpaint_detail_preservation": 4, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From b25c126ccdbc4da22ade46597a9addf808998989 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:38:53 -0500 Subject: [PATCH 0335/1174] Protect alphas_cumprod from downcasting --- modules/sd_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 841402e8629..de80a493a6f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -387,7 +387,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.upcast_sampling and depth_model: model.depth_model = None + alphas_cumprod = model.alphas_cumprod + model.alphas_cumprod = None model.half() + model.alphas_cumprod = alphas_cumprod + model.alphas_cumprod_original = alphas_cumprod model.first_stage_model = vae if depth_model: model.depth_model = depth_model @@ -642,6 +646,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): else: weight_dtype_conversion = { 'first_stage_model': None, + 'alphas_cumprod': None, '': torch.float16, } From 588a52891dca4d030ca7028dd9c0b56022a68b57 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:40:23 -0500 Subject: [PATCH 0336/1174] Add options for zero terminal SNR --- modules/shared_options.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/shared_options.py b/modules/shared_options.py index 04e68a712ef..51596777044 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -218,6 +218,7 @@ "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": 
OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.") })) options_templates.update(options_section(('interrogate', "Interrogate"), { @@ -335,6 +336,7 @@ 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), + 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise schedule for sampling").info("for use with zero terminal SNR trained models") })) options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { From 6d0a8dcd892f7ad9b399fed6edbad6ede13c5f69 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:42:07 -0500 Subject: [PATCH 0337/1174] Implement zero terminal SNR schedule option --- modules/processing.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index ac58ef86940..c88eec70589 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -863,6 +863,34 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" + + def rescale_zero_terminal_snr_abar(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. 
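+        # Equivalently: abar_sqrt[t] := (abar_sqrt[t] - abar_sqrt[T]) * abar_sqrt[0] / (abar_sqrt[0] - abar_sqrt[T]),
+        # which restores abar_sqrt[0] to its original value while the terminal entry stays at exactly zero.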
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return alphas_bar + + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + print("rescaling noise schedule for zero snr") + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From ec6ee5c13bf3453f8703e225a191333a9bbcf10a Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:10:27 -0500 Subject: [PATCH 0338/1174] Fix infotext for ztSNR --- modules/shared_options.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 51596777044..bc3d56dec60 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -218,7 +218,7 @@ "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), - "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.") + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. 
For reproducing old seeds.", infotext="Downcast alphas_cumprod") })) options_templates.update(options_section(('interrogate', "Interrogate"), { @@ -336,7 +336,7 @@ 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), - 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise schedule for sampling").info("for use with zero terminal SNR trained models") + 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models") })) options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { From ffa7f8201d849636bb327b3b40298e7c169ff204 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:10:43 -0500 Subject: [PATCH 0339/1174] Lint --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c88eec70589..f3883d5b3c2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -863,7 +863,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - + def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar_sqrt = alphas_cumprod.sqrt() @@ -881,7 +881,7 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar = alphas_bar_sqrt**2 # Revert sqrt alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) if opts.use_downcasted_alpha_bar: From de79597ab9894965e3702939b8536ec3dcc3c859 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:33:32 -0500 Subject: [PATCH 0340/1174] Only apply ztSNR related code if alphas_cumprod exists --- modules/processing.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f3883d5b3c2..7e73d7e2988 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,15 +882,16 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - print("rescaling noise schedule for zero snr") - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + 
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + print("rescaling noise schedule for zero snr") + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From 668ae34e21df848ef4909b8b49c4142a3674701b Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 22:48:31 -0500 Subject: [PATCH 0341/1174] remove debug print --- modules/processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 7e73d7e2988..d73c8bfc1c9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -890,7 +890,6 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) if opts.sd_noise_schedule == "Zero Terminal SNR": p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - print("rescaling noise schedule for zero snr") p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): From 8b40f475a31109cc6ecbdc0d14a0cee9e0303291 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Fri, 10 Nov 2023 11:06:26 +0800 Subject: [PATCH 0342/1174] Initial IPEX support --- modules/devices.py | 11 +++++++++-- modules/xpu_specific.py | 42 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 modules/xpu_specific.py diff --git a/modules/devices.py b/modules/devices.py index 1d4eb563517..be599736cba 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ from functools import lru_cache import torch -from modules import errors, shared +from modules import errors, shared, xpu_specific if sys.platform == "darwin": from modules import mac_specific @@ -30,6 +30,9 @@ def get_optimal_device_name(): if has_mps(): return "mps" + if xpu_specific.has_ipex: + return xpu_specific.get_xpu_device_string() + return "cpu" @@ -100,11 +103,15 @@ def autocast(disable=False): if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() + if xpu_specific.has_xpu: + return torch.autocast("xpu") + return torch.autocast("cuda") def without_autocast(disable=False): - return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + device_type = "xpu" if xpu_specific.has_xpu else "cuda" + return torch.autocast(device_type, enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() class NansException(Exception): diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py new file mode 100644 index 00000000000..6417dd2d659 --- /dev/null +++ b/modules/xpu_specific.py @@ -0,0 +1,42 @@ +import contextlib +from modules import shared +from modules.sd_hijack_utils import CondFunc + +has_ipex = False +try: + import torch + import intel_extension_for_pytorch as ipex + has_ipex = True +except Exception: + pass + +def 
check_for_xpu():
+    if not has_ipex:
+        return False
+
+    return hasattr(torch, 'xpu') and torch.xpu.is_available()
+
+has_xpu = check_for_xpu()
+
+def get_xpu_device_string():
+    if shared.cmd_opts.device_id is not None:
+        return f"xpu:{shared.cmd_opts.device_id}"
+    return "xpu"
+
+def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
+    return contextlib.nullcontext()
+
+if has_xpu:
+    CondFunc('torch.Generator',
+        lambda orig_func, device=None: torch.xpu.Generator(device),
+        lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
+
+    CondFunc('torch.nn.functional.layer_norm',
+        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
+        orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
+        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
+        weight is not None and input.dtype != weight.data.dtype)
+
+    CondFunc('torch.nn.modules.GroupNorm.forward',
+        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)

From c2ed4132037a32cda856e8ba6e2cda32b44b9784 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Fri, 1 Dec 2023 02:59:41 +0900
Subject: [PATCH 0343/1174] add max-height/width to global-popup-inner

prevent the pop-up from becoming so big that exiting the pop-up is impossible

---
 style.css | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/style.css b/style.css
index 6e3ca84113e..ee39a57b73e 100644
--- a/style.css
+++ b/style.css
@@ -646,6 +646,8 @@ table.popup-table .link{
     margin: auto;
     padding: 2em;
     z-index: 1001;
+    max-height: 90%;
+    max-width: 90%;
 }

 /* fullpage image viewer */

From 01c8f1803a77c63b2ebfd3cbbd41659fb914f274 Mon Sep 17 00:00:00 2001
From: missionfloyd
Date: Thu, 30 Nov 2023 22:36:12 -0700
Subject: [PATCH 0344/1174] Close popups with escape key

---
 javascript/extraNetworks.js | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index a787372cff3..98a7abb745c 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -392,3 +392,9 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
         }
     });
 }
+
+window.addEventListener("keydown", function(event) {
+    if (event.key == "Escape") {
+        closePopup();
+    }
+});

From 293f44e6c1de7bbf744a4236db81ac4559bdb82a Mon Sep 17 00:00:00 2001
From: MrCheeze
Date: Fri, 1 Dec 2023 22:56:08 -0500
Subject: [PATCH 0345/1174] Fix bug where is_using_v_parameterization_for_sd2
 fails because the sd_hijack is only partially undone

---
 modules/sd_hijack.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 0157e19f003..3d340fc9b21 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -38,9 +38,6 @@
 optimizers = []
 current_optimizer: sd_hijack_optimizations.SdOptimization = None

-ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
-sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
-

 def list_optimizers():
     new_optimizers = script_callbacks.list_optimizers_callback()
@@ -258,6 +255,9 @@ def flatten(el):

         import modules.models.diffusion.ddpm_edit

+        ldm_original_forward = patches.patch(__file__, 
ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) + sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) + if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion): sd_unet.original_forward = ldm_original_forward elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion): @@ -303,6 +303,9 @@ def undo_hijack(self, m): self.layers = None self.clip = None + patches.undo(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward") + patches.undo(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward") + sd_unet.original_forward = None From 6080045b2a0964e63bdcd33dd26015f8a51411f6 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Fri, 1 Dec 2023 22:58:05 -0500 Subject: [PATCH 0346/1174] Add support for SD 2.1 Turbo, by converting the state dict from SGM to LDM on load --- modules/sd_models.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 841402e8629..9355f1e16b7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -230,15 +230,19 @@ def select_checkpoint(): return checkpoint_info -checkpoint_dict_replacements = { +checkpoint_dict_replacements_sd1 = { 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', } +checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format. + 'conditioner.embedders.0.': 'cond_stage_model.', +} + -def transform_checkpoint_dict_key(k): - for text, replacement in checkpoint_dict_replacements.items(): +def transform_checkpoint_dict_key(k, replacements): + for text, replacement in replacements.items(): if k.startswith(text): k = replacement + k[len(text):] @@ -249,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd): pl_sd = pl_sd.pop("state_dict", pl_sd) pl_sd.pop("state_dict", None) + is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024 + sd = {} for k, v in pl_sd.items(): - new_key = transform_checkpoint_dict_key(k) + if is_sd2_turbo: + new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo) + else: + new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1) if new_key is not None: sd[new_key] = v From b58d061e41cba6fb91910d310d53e175d0511650 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Dec 2023 08:33:28 +0300 Subject: [PATCH 0347/1174] infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page --- modules/generation_parameters_copypaste.py | 13 +++++++++---- modules/processing.py | 4 ++-- modules/shared_items.py | 16 ++++++++++++++++ modules/shared_options.py | 20 ++++++++++++++------ 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 0a606515b79..4efe53e0c48 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,3 +1,4 @@ +from __future__ import annotations import 
base64 import io import json @@ -15,9 +16,6 @@ re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") type_of_gr_update = type(gr.update()) -paste_fields = {} -registered_param_bindings = [] - class ParamBinding: def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None): @@ -30,6 +28,10 @@ def __init__(self, paste_button, tabname, source_text_component=None, source_ima self.paste_field_names = paste_field_names or [] +paste_fields: dict[str, dict] = {} +registered_param_bindings: list[ParamBinding] = [] + + def reset(): paste_fields.clear() registered_param_bindings.clear() @@ -113,7 +115,6 @@ def register_paste_params_button(binding: ParamBinding): def connect_paste_params_buttons(): - binding: ParamBinding for binding in registered_param_bindings: destination_image_component = paste_fields[binding.tabname]["init_img"] fields = paste_fields[binding.tabname]["fields"] @@ -313,6 +314,9 @@ def parse_generation_parameters(x: str): if "VAE Decoder" not in res: res["VAE Decoder"] = "Full" + skip = set(shared.opts.infotext_skip_pasting) + res = {k: v for k, v in res.items() if k not in skip} + return res @@ -443,3 +447,4 @@ def paste_settings(params): outputs=[], show_progress=False, ) + diff --git a/modules/processing.py b/modules/processing.py index ac58ef86940..5ab6dddef39 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -679,8 +679,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None, - "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None, - "VAE": p.sd_vae_name if opts.add_model_name_to_info else None, + "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None, + "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), diff --git a/modules/shared_items.py b/modules/shared_items.py index 5024b4268d0..991971ad0fb 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -66,6 +66,22 @@ def reload_hypernetworks(): shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) +def get_infotext_names(): + from modules import generation_parameters_copypaste, shared + res = {} + + for info in shared.opts.data_labels.values(): + if info.infotext: + res[info.infotext] = 1 + + for tab_data in generation_parameters_copypaste.paste_fields.values(): + for _, name in tab_data.get("fields") or []: + if isinstance(name, str): + res[name] = 1 + + return list(res) + + ui_reorder_categories_builtin_items = [ "prompt", "image", diff --git a/modules/shared_options.py b/modules/shared_options.py index 04e68a712ef..df45fc0a0c9 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -46,8 +46,6 @@ "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), - "enable_pnginfo": 
OptionInfo(True, "Save text information about generation parameters as chunks to png files"), - "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), @@ -288,11 +286,21 @@ options_templates.update(options_section(('infotext', "Infotext", "ui"), { - "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), - "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), - "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), + "infotext_explanation": OptionHTML(""" +Infotext is what this software calls the text that contains generation parameters and can be used to generate the same picture again. +It is displayed in UI below the image. To use infotext, paste it into the prompt and click the ↙️ paste button. +"""), + "enable_pnginfo": OptionInfo(True, "Write infotext to metadata of the generated image"), + "save_txt": OptionInfo(False, "Create a text file with infotext next to every generated image"), + + "add_model_name_to_info": OptionInfo(True, "Add model name to infotext"), + "add_model_hash_to_info": OptionInfo(True, "Add model hash to infotext"), + "add_vae_name_to_info": OptionInfo(True, "Add VAE name to infotext"), + "add_vae_hash_to_info": OptionInfo(True, "Add VAE hash to infotext"), + "add_user_name_to_info": OptionInfo(False, "Add user name to infotext when authenticated"), + "add_version_to_infotext": OptionInfo(True, "Add program version to infotext"), "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"), + "infotext_skip_pasting": OptionInfo([], "Disregard fields from pasted infotext", ui_components.DropdownMulti, lambda: {"choices": shared_items.get_infotext_names()}), "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""
  • Ignore: keep prompt and styles dropdown as it is.
  • Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).
  • Discard: remove style text from prompt, keep styles dropdown as it is.<br/>
  • Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is."""),
}))

From 7499148ad4dbd3444215c843d02453f68c459707 Mon Sep 17 00:00:00 2001
From: Nuullll
Date: Sat, 2 Dec 2023 14:00:46 +0800
Subject: [PATCH 0348/1174] Disable ipex autocast due to its bad perf

---
 modules/cmd_args.py     |  1 +
 modules/devices.py      | 20 +++++++++++++-------
 modules/xpu_specific.py | 28 ++++++++++++++++++----------
 webui-ipex-user.bat     | 19 +++++++++++++++++++
 4 files changed, 51 insertions(+), 17 deletions(-)
 create mode 100644 webui-ipex-user.bat

diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index a9fb9bfa3ef..da93eb2669f 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -70,6 +70,7 @@
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
 parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
 parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
 parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)

diff --git a/modules/devices.py b/modules/devices.py
index be599736cba..37ecca78430 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,11 +3,18 @@
 from functools import lru_cache

 import torch
-from modules import errors, shared, xpu_specific
+from modules import errors, shared

 if sys.platform == "darwin":
     from modules import mac_specific

+if shared.cmd_opts.use_ipex:
+    from modules import xpu_specific
+
+
+def has_xpu() -> bool:
+    return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
+

 def has_mps() -> bool:
     if sys.platform != "darwin":
@@ -30,7 +37,7 @@ def get_optimal_device_name():
     if has_mps():
         return "mps"

-    if xpu_specific.has_ipex:
+    if has_xpu():
         return xpu_specific.get_xpu_device_string()

     return "cpu"
@@ -57,6 +64,9 @@ def torch_gc():
     if has_mps():
         mac_specific.torch_mps_gc()

+    if has_xpu():
+        xpu_specific.torch_xpu_gc()
+

 def enable_tf32():
     if torch.cuda.is_available():
@@ -103,15 +113,11 @@ def autocast(disable=False):
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()

-    if xpu_specific.has_xpu:
-        return torch.autocast("xpu")
-
     return torch.autocast("cuda")


 def without_autocast(disable=False):
-    device_type = "xpu" if xpu_specific.has_xpu else "cuda"
-    return torch.autocast(device_type, enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()


 class NansException(Exception):

diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index 6417dd2d659..2df68665ac9 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -1,4 +1,3 @@
-import contextlib
 from modules import shared
 from modules.sd_hijack_utils import CondFunc

@@ -10,33 +9,42 @@
 except Exception:
     pass

-def check_for_xpu():
-    if not has_ipex:
-        return 
False - return hasattr(torch, 'xpu') and torch.xpu.is_available() +def check_for_xpu(): + return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available() -has_xpu = check_for_xpu() def get_xpu_device_string(): if shared.cmd_opts.device_id is not None: return f"xpu:{shared.cmd_opts.device_id}" return "xpu" -def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return contextlib.nullcontext() + +def torch_xpu_gc(): + with torch.xpu.device(get_xpu_device_string()): + torch.xpu.empty_cache() + + +has_xpu = check_for_xpu() if has_xpu: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device CondFunc('torch.Generator', lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + lambda orig_func, device=None: device is not None and device.type == "xpu") + # W/A for some OPs that could not handle different input dtypes CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight is not None and input.dtype != weight.data.dtype) - CondFunc('torch.nn.modules.GroupNorm.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.linear.Linear.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.nn.modules.conv.Conv2d.forward', + lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), + lambda orig_func, self, input: input.dtype != self.weight.data.dtype) diff --git a/webui-ipex-user.bat b/webui-ipex-user.bat new file mode 100644 index 00000000000..ab25a0400d7 --- /dev/null +++ b/webui-ipex-user.bat @@ -0,0 +1,19 @@ +@echo off + +set PYTHON= +@REM The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main +@REM This is NOT an Intel official release so please use it at your own risk!! +@REM See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details. +@REM +@REM Strengths (over official IPEX 2.0.110 windows release): +@REM - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399 +@REM - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system. 
+@REM - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465 +@REM Limitation: +@REM - Only works for python 3.10 +set "TORCH_COMMAND=pip install https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl" +set GIT= +set VENV_DIR= +set "COMMANDLINE_ARGS=--use-ipex --skip-torch-cuda-test --skip-version-check --opt-sdp-attention" + +call webui.bat From e294e46d46a814457fc77af13c17128bd6075d45 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Dec 2023 09:26:38 +0300 Subject: [PATCH 0349/1174] split UI settings page into many --- .../scripts/extra_options_section.py | 13 +++-- modules/shared_options.py | 57 +++++++++++-------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 983f87ff033..a903df62548 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -64,11 +64,14 @@ def before_process(self, p, *args): p.override_settings[name] = value -shared.options_templates.update(shared.options_section(('ui', "User interface"), { - "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), - "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(), - "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui() +shared.options_templates.update(shared.options_section(('settings_in_ui', "Settings in UI", "ui"), { + "settings_in_ui": shared.OptionHTML(""" +This page allows you to add some settings to the main interface of txt2img and img2img tabs. 
+"""), + "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), + "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) diff --git a/modules/shared_options.py b/modules/shared_options.py index df45fc0a0c9..1390152db78 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -250,38 +250,45 @@ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks), })) -options_templates.update(options_section(('ui', "User interface", "ui"), { - "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), - "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the gallery.").needs_reload_ui(), - "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), - "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(), - "return_grid": OptionInfo(True, "Show grid in results for web"), - "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), - "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), - "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), - "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), - "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"), - "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"), - "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), +options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), { + "keyedit_precision_attention": OptionInfo(0.1, "Precision for (attention:1.1) when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_precision_extra": OptionInfo(0.05, "Precision for when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"), + "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), +})) + 
+options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), { + "return_grid": OptionInfo(True, "Show grid in gallery"), + "do_not_show_images": OptionInfo(False, "Do not show any images in gallery"), + "js_modal_lightbox": OptionInfo(True, "Full page image viewer: enable"), + "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"), + "js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"), + "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"), + "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(), +})) + +options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), { + "compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), - "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"), - "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), - "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), - "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), - "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), - "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), - "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), "sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"), "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), - "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(), "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden 
under Accordion").needs_reload_ui(), - "compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(), +})) + +options_templates.update(options_section(('ui', "User interface", "ui"), { + "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), + "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), + "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), + "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), + "ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), + "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the gallery.").needs_reload_ui(), + "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), + "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), + "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), })) From ef6b8123dc57e4e4bd5e08d9f3e3dbdfdf6b4c4a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Dec 2023 09:57:39 +0300 Subject: [PATCH 0350/1174] put code that can cause an exception into its own function for #14120 --- modules/scripts.py | 62 ++++++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 961d032cea0..7f9454eb578 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -560,54 +560,58 @@ def apply_on_before_component_callbacks(self): on_after.clear() def create_script_ui(self, script): - import modules.api.models as api_models script.args_from = len(self.inputs) script.args_to = len(self.inputs) + try: + self.create_script_ui_inner(script) + except Exception: + errors.report(f"Error creating UI for {script.name}: ", exc_info=True) + + def create_script_ui_inner(self, script): + import modules.api.models as api_models + controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) if controls is None: return - try: - script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() - api_args = [] + script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() - for control in controls: - control.custom_script_source = os.path.basename(script.filename) + api_args = [] - arg_info = api_models.ScriptArg(label=control.label or "") + for control in controls: + 
control.custom_script_source = os.path.basename(script.filename) - for field in ("value", "minimum", "maximum", "step"): - v = getattr(control, field, None) - if v is not None: - setattr(arg_info, field, v) + arg_info = api_models.ScriptArg(label=control.label or "") - choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string - if choices is not None: - arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices] + for field in ("value", "minimum", "maximum", "step"): + v = getattr(control, field, None) + if v is not None: + setattr(arg_info, field, v) - api_args.append(arg_info) + choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string + if choices is not None: + arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices] - script.api_info = api_models.ScriptInfo( - name=script.name, - is_img2img=script.is_img2img, - is_alwayson=script.alwayson, - args=api_args, - ) + api_args.append(arg_info) - if script.infotext_fields is not None: - self.infotext_fields += script.infotext_fields + script.api_info = api_models.ScriptInfo( + name=script.name, + is_img2img=script.is_img2img, + is_alwayson=script.alwayson, + args=api_args, + ) - if script.paste_field_names is not None: - self.paste_field_names += script.paste_field_names + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields - self.inputs += controls - script.args_to = len(self.inputs) + if script.paste_field_names is not None: + self.paste_field_names += script.paste_field_names - except Exception: - errors.report(f"Error creating UI for {script.name}: ", exc_info=True) + self.inputs += controls + script.args_to = len(self.inputs) def setup_ui_for_section(self, section, scriptlist=None): if scriptlist is None: From 87cd07b3af74c447b02570bf3963ba83ade2e203 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 2 Dec 2023 15:54:25 +0800 Subject: [PATCH 0351/1174] Fix fp64 --- modules/sd_samplers_timesteps_impl.py | 4 ++-- modules/xpu_specific.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py index a72daafd47d..930a64af590 100644 --- a/modules/sd_samplers_timesteps_impl.py +++ b/modules/sd_samplers_timesteps_impl.py @@ -11,7 +11,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) @@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta= def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas = alphas_cumprod[timesteps] - alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else 
torch.float32) + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32) sqrt_one_minus_alphas = torch.sqrt(1 - alphas) extra_args = {} if extra_args is None else extra_args diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 2df68665ac9..d933c790378 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -4,7 +4,7 @@ has_ipex = False try: import torch - import intel_extension_for_pytorch as ipex + import intel_extension_for_pytorch as ipex # noqa: F401 has_ipex = True except Exception: pass From 4a666381bf98333ba4512db0f0033df5f6a08771 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Dec 2023 12:11:21 +0300 Subject: [PATCH 0352/1174] extras tab batch: actually use original filename preprocessing upscale: do not do an extra upscale step if it's not needed --- modules/postprocessing.py | 4 +++- modules/upscaler.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index fd0c0cc99c5..0a134ee4352 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -60,8 +60,10 @@ def get_images(extras_mode, image, image_folder, input_dir): if opts.use_original_name_batch and name is not None: basename = os.path.splitext(os.path.basename(name))[0] + forced_filename = basename else: basename = '' + forced_filename = None infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) @@ -70,7 +72,7 @@ def get_images(extras_mode, image, image_folder, input_dir): pp.image.info["postprocessing"] = infotext if save_output: - images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) + images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename) if extras_mode != 2 or show_extras_results: outputs.append(pp.image) diff --git a/modules/upscaler.py b/modules/upscaler.py index e682bbaa26c..b256e085b6d 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -57,6 +57,9 @@ def upscale(self, img: PIL.Image, scale, selected_model: str = None): dest_h = int((img.height * scale) // 8 * 8) for _ in range(3): + if img.width >= dest_w and img.height >= dest_h: + break + shape = (img.width, img.height) img = self.do_upscale(img, selected_model) @@ -64,9 +67,6 @@ def upscale(self, img: PIL.Image, scale, selected_model: str = None): if shape == (img.width, img.height): break - if img.width >= dest_w and img.height >= dest_h: - break - if img.width != dest_w or img.height != dest_h: img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) From 96871e4f744471177d97e01c49f8587d7f67c125 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 2 Dec 2023 17:11:11 +0800 Subject: [PATCH 0353/1174] Remove webui-ipex-user.bat --- modules/launch_utils.py | 22 ++++++++++++++++++++++ webui-ipex-user.bat | 19 ------------------- 2 files changed, 22 insertions(+), 19 deletions(-) delete mode 100644 webui-ipex-user.bat diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 
264ec9ca6a4..586cdc7eb8c 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -310,6 +310,26 @@ def requirements_met(requirements_file): def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") + if args.use_ipex: + if platform.system() == "Windows": + # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main + # This is NOT an Intel official release so please use it at your own risk!! + # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details. + # + # Strengths (over official IPEX 2.0.110 windows release): + # - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399 + # - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system. + # - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465 + # Limitation: + # - Only works for python 3.10 + url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle" + torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl") + else: + # Using official IPEX release for linux since it's already an AOT build. + # However, users still have to install oneAPI toolkit and activate oneAPI environment manually. + # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details. + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') @@ -352,6 +372,8 @@ def prepare_environment(): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) startup_timer.record("install torch") + if args.use_ipex: + args.skip_torch_cuda_test = True if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"): raise RuntimeError( 'Torch is not able to use GPU; ' diff --git a/webui-ipex-user.bat b/webui-ipex-user.bat deleted file mode 100644 index ab25a0400d7..00000000000 --- a/webui-ipex-user.bat +++ /dev/null @@ -1,19 +0,0 @@ -@echo off - -set PYTHON= -@REM The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main -@REM This is NOT an Intel official release so please use it at your own risk!! -@REM See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details. 
-@REM -@REM Strengths (over official IPEX 2.0.110 windows release): -@REM - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399 -@REM - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system. -@REM - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465 -@REM Limitation: -@REM - Only works for python 3.10 -set "TORCH_COMMAND=pip install https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%%2Bxpu-master%%2Bdll-bundle/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl" -set GIT= -set VENV_DIR= -set "COMMANDLINE_ARGS=--use-ipex --skip-torch-cuda-test --skip-version-check --opt-sdp-attention" - -call webui.bat From 50a21cb09fe3e9ea2d4fe058e0484e192c8a86e3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 2 Dec 2023 22:06:47 +0800 Subject: [PATCH 0354/1174] Ensure the cached weight will not be affected --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4b8a9ae653a..dcf816b3a85 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -435,9 +435,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer for module in model.modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): if shared.opts.cache_fp16_weight: - module.fp16_weight = module.weight.clone().half() + module.fp16_weight = module.weight.data.clone().cpu().half() if module.bias is not None: - module.fp16_bias = module.bias.clone().half() + module.fp16_bias = module.bias.data.clone().cpu().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") From 11d23e8ca55c097ecfa255a05b63f194e25f08be Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Dec 2023 18:01:11 +0300 Subject: [PATCH 0355/1174] remove Train/Preprocessing tab and put all its functionality into extras batch images mode --- javascript/ui.js | 17 ++ modules/api/api.py | 15 -- modules/api/models.py | 3 - modules/postprocessing.py | 92 +++++-- modules/scripts_postprocessing.py | 86 ++++++- modules/shared_options.py | 1 + modules/textual_inversion/preprocess.py | 232 ------------------ modules/textual_inversion/ui.py | 7 - modules/ui.py | 107 -------- modules/ui_postprocessing.py | 16 +- modules/ui_toprow.py | 6 +- scripts/postprocessing_caption.py | 30 +++ scripts/postprocessing_codeformer.py | 16 +- .../postprocessing_create_flipped_copies.py | 32 +++ scripts/postprocessing_focal_crop.py | 54 ++++ scripts/postprocessing_gfpgan.py | 13 +- scripts/postprocessing_split_oversized.py | 71 ++++++ scripts/postprocessing_upscale.py | 12 + scripts/processing_autosized_crop.py | 64 +++++ 19 files changed, 460 insertions(+), 414 deletions(-) delete mode 100644 modules/textual_inversion/preprocess.py create mode 100644 scripts/postprocessing_caption.py create mode 100644 scripts/postprocessing_create_flipped_copies.py create mode 100644 
scripts/postprocessing_focal_crop.py create mode 100644 scripts/postprocessing_split_oversized.py create mode 100644 scripts/processing_autosized_crop.py diff --git a/javascript/ui.js b/javascript/ui.js index 2e262602044..410fc44e3d9 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -170,6 +170,23 @@ function submit_img2img() { return res; } +function submit_extras() { + showSubmitButtons('extras', false); + + var id = randomId(); + + requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() { + showSubmitButtons('extras', true); + }); + + var res = create_submit_args(arguments); + + res[0] = id; + + console.log(res); + return res; +} + function restoreProgressTxt2img() { showRestoreProgressButton("txt2img", false); var id = localGet("txt2img_task_id"); diff --git a/modules/api/api.py b/modules/api/api.py index 09083874792..b3d74e513a3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -22,7 +22,6 @@ from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding -from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin, Image from modules.sd_models_config import find_checkpoint_config_near_filename @@ -235,7 +234,6 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) - self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) @@ -675,19 +673,6 @@ def create_hypernetwork(self, args: dict): finally: shared.state.end() - def preprocess(self, args: dict): - try: - shared.state.begin(job="preprocess") - preprocess(**args) # quick operation unless blip/booru interrogation is enabled - shared.state.end() - return models.PreprocessResponse(info='preprocess complete') - except KeyError as e: - return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") - except Exception as e: - return models.PreprocessResponse(info=f"preprocess error: {e}") - finally: - shared.state.end() - def train_embedding(self, args: dict): try: shared.state.begin(job="train_embedding") diff --git a/modules/api/models.py b/modules/api/models.py index a0d80af8c0d..33894b3e694 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -202,9 +202,6 @@ class TrainResponse(BaseModel): class CreateResponse(BaseModel): info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.") -class PreprocessResponse(BaseModel): - info: str = Field(title="Preprocess info", description="Response string from preprocessing task.") - 
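(Alongside the UI removal, this patch drops the /sdapi/v1/preprocess route and its PreprocessResponse model, so existing API clients will start receiving 404s. A quick migration check against a running instance -- assuming the default local address -- could look like:

    import requests

    base_url = "http://127.0.0.1:7860"  # assumed default; adjust to your setup
    resp = requests.post(f"{base_url}/sdapi/v1/preprocess", json={})
    assert resp.status_code == 404, "endpoint still present -- older build?"

The replacement workflow is the extras batch mode driven by the postprocessing scripts added further down in this patch.)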
fields = {} for key, metadata in opts.data_labels.items(): value = opts.data.get(key) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 0a134ee4352..3c85a74c1b3 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -6,7 +6,7 @@ from modules.shared import opts -def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): +def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() shared.state.begin(job="extras") @@ -29,11 +29,7 @@ def get_images(extras_mode, image, image_folder, input_dir): image_list = shared.listfiles(input_dir) for filename in image_list: - try: - image = Image.open(filename) - except Exception: - continue - yield image, filename + yield filename, filename else: assert image, 'image not selected' yield image, None @@ -45,37 +41,85 @@ def get_images(extras_mode, image, image_folder, input_dir): infotext = '' - for image_data, name in get_images(extras_mode, image, image_folder, input_dir): + data_to_process = list(get_images(extras_mode, image, image_folder, input_dir)) + shared.state.job_count = len(data_to_process) + + for image_placeholder, name in data_to_process: image_data: Image.Image + shared.state.nextjob() shared.state.textinfo = name + shared.state.skipped = False + + if shared.state.interrupted: + break + + if isinstance(image_placeholder, str): + try: + image_data = Image.open(image_placeholder) + except Exception: + continue + else: + image_data = image_placeholder + + shared.state.assign_current_image(image_data) parameters, existing_pnginfo = images.read_info_from_image(image_data) if parameters: existing_pnginfo["parameters"] = parameters - pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB")) + initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB")) - scripts.scripts_postproc.run(pp, args) + scripts.scripts_postproc.run(initial_pp, args) - if opts.use_original_name_batch and name is not None: - basename = os.path.splitext(os.path.basename(name))[0] - forced_filename = basename - else: - basename = '' - forced_filename = None + if shared.state.skipped: + continue + + used_suffixes = {} + for pp in [initial_pp, *initial_pp.extra_images]: + suffix = pp.get_suffix(used_suffixes) + + if opts.use_original_name_batch and name is not None: + basename = os.path.splitext(os.path.basename(name))[0] + forced_filename = basename + suffix + else: + basename = '' + forced_filename = None + + infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + + if opts.enable_pnginfo: + pp.image.info = existing_pnginfo + pp.image.info["postprocessing"] = infotext + + if save_output: + fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix) - infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + if pp.caption: + caption_filename = os.path.splitext(fullfn)[0] + ".txt" + if os.path.isfile(caption_filename): + with open(caption_filename, encoding="utf8") as file: + existing_caption = file.read().strip() + else: + existing_caption = "" - if 
opts.enable_pnginfo: - pp.image.info = existing_pnginfo - pp.image.info["postprocessing"] = infotext + action = shared.opts.postprocessing_existing_caption_action + if action == 'Prepend' and existing_caption: + caption = f"{existing_caption} {pp.caption}" + elif action == 'Append' and existing_caption: + caption = f"{pp.caption} {existing_caption}" + elif action == 'Keep' and existing_caption: + caption = existing_caption + else: + caption = pp.caption - if save_output: - images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename) + caption = caption.strip() + if caption: + with open(caption_filename, "w", encoding="utf8") as file: + file.write(caption) - if extras_mode != 2 or show_extras_results: - outputs.append(pp.image) + if extras_mode != 2 or show_extras_results: + outputs.append(pp.image) image_data.close() @@ -99,9 +143,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ "upscaler_2_visibility": extras_upscaler_2_visibility, }, "GFPGAN": { + "enable": True, "gfpgan_visibility": gfpgan_visibility, }, "CodeFormer": { + "enable": True, "codeformer_visibility": codeformer_visibility, "codeformer_weight": codeformer_weight, }, diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index bac1335dc35..901cad080c0 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -1,13 +1,56 @@ +import dataclasses import os import gradio as gr from modules import errors, shared +@dataclasses.dataclass +class PostprocessedImageSharedInfo: + target_width: int = None + target_height: int = None + + class PostprocessedImage: def __init__(self, image): self.image = image self.info = {} + self.shared = PostprocessedImageSharedInfo() + self.extra_images = [] + self.nametags = [] + self.disable_processing = False + self.caption = None + + def get_suffix(self, used_suffixes=None): + used_suffixes = {} if used_suffixes is None else used_suffixes + suffix = "-".join(self.nametags) + if suffix: + suffix = "-" + suffix + + if suffix not in used_suffixes: + used_suffixes[suffix] = 1 + return suffix + + for i in range(1, 100): + proposed_suffix = suffix + "-" + str(i) + + if proposed_suffix not in used_suffixes: + used_suffixes[proposed_suffix] = 1 + return proposed_suffix + + return suffix + + def create_copy(self, new_image, *, nametags=None, disable_processing=False): + pp = PostprocessedImage(new_image) + pp.shared = self.shared + pp.nametags = self.nametags.copy() + pp.info = self.info.copy() + pp.disable_processing = disable_processing + + if nametags is not None: + pp.nametags += nametags + + return pp class ScriptPostprocessing: @@ -42,10 +85,17 @@ def process(self, pp: PostprocessedImage, **args): pass - def image_changed(self): - pass + def process_firstpass(self, pp: PostprocessedImage, **args): + """ + Called for all scripts before calling process(). Scripts can examine the image here and set fields + of the pp object to communicate things to other scripts. 
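(A minimal hypothetical script illustrating this hook -- not part of the patch; the upscale scripts further down use the same pattern to publish target dimensions:

    from modules import scripts_postprocessing

    class ScriptPostprocessingExample(scripts_postprocessing.ScriptPostprocessing):
        name = "Example"
        order = 9000

        def ui(self):
            return {}  # no controls needed for this sketch

        def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage):
            # announce a target size that later scripts (e.g. focal crop) can read
            pp.shared.target_width = 512
            pp.shared.target_height = 512

        def process(self, pp: scripts_postprocessing.PostprocessedImage):
            pp.info["Example"] = "1"  # joined into the saved infotext
)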
+ args contains a dictionary with all values returned by components from ui() + """ + pass + def image_changed(self): + pass def wrap_call(func, filename, funcname, *args, default=None, **kwargs): @@ -118,16 +168,42 @@ def setup_ui(self): return inputs def run(self, pp: PostprocessedImage, args): - for script in self.scripts_in_preferred_order(): - shared.state.job = script.name + scripts = [] + for script in self.scripts_in_preferred_order(): script_args = args[script.args_from:script.args_to] process_args = {} for (name, _component), value in zip(script.controls.items(), script_args): process_args[name] = value - script.process(pp, **process_args) + scripts.append((script, process_args)) + + for script, process_args in scripts: + script.process_firstpass(pp, **process_args) + + all_images = [pp] + + for script, process_args in scripts: + if shared.state.skipped: + break + + shared.state.job = script.name + + for single_image in all_images.copy(): + + if not single_image.disable_processing: + script.process(single_image, **process_args) + + for extra_image in single_image.extra_images: + if not isinstance(extra_image, PostprocessedImage): + extra_image = single_image.create_copy(extra_image) + + all_images.append(extra_image) + + single_image.extra_images.clear() + + pp.extra_images = all_images[1:] def create_args_for_run(self, scripts_args): if not self.ui_created: diff --git a/modules/shared_options.py b/modules/shared_options.py index d8a27180ef6..859dee40456 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -357,6 +357,7 @@ 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + 'postprocessing_existing_caption_action': OptionInfo("Ignore", "Action for existing captions", gr.Radio, {"choices": ["Ignore", "Keep", "Prepend", "Append"]}).info("when generating captions using postprocessing; Ignore = use generated; Keep = use original; Prepend/Append = combine both"), })) options_templates.update(options_section((None, "Hidden options"), { diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py deleted file mode 100644 index 789fa08384b..00000000000 --- a/modules/textual_inversion/preprocess.py +++ /dev/null @@ -1,232 +0,0 @@ -import os -from PIL import Image, ImageOps -import math -import tqdm - -from modules import shared, images, deepbooru -from modules.textual_inversion import autocrop - - -def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): - 
try: - if process_caption: - shared.interrogator.load() - - if process_caption_deepbooru: - deepbooru.model.start() - - preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) - - finally: - - if process_caption: - shared.interrogator.send_blip_to_ram() - - if process_caption_deepbooru: - deepbooru.model.stop() - - -def listfiles(dirname): - return os.listdir(dirname) - - -class PreprocessParams: - src = None - dstdir = None - subindex = 0 - flip = False - process_caption = False - process_caption_deepbooru = False - preprocess_txt_action = None - - -def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None): - caption = "" - - if params.process_caption: - caption += shared.interrogator.generate_caption(image) - - if params.process_caption_deepbooru: - if caption: - caption += ", " - caption += deepbooru.model.tag_multi(image) - - filename_part = params.src - filename_part = os.path.splitext(filename_part)[0] - filename_part = os.path.basename(filename_part) - - basename = f"{index:05}-{params.subindex}-{filename_part}" - image.save(os.path.join(params.dstdir, f"{basename}.png")) - - if params.preprocess_txt_action == 'prepend' and existing_caption: - caption = f"{existing_caption} {caption}" - elif params.preprocess_txt_action == 'append' and existing_caption: - caption = f"{caption} {existing_caption}" - elif params.preprocess_txt_action == 'copy' and existing_caption: - caption = existing_caption - - caption = caption.strip() - - if caption: - with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: - file.write(caption) - - params.subindex += 1 - - -def save_pic(image, index, params, existing_caption=None): - save_pic_with_caption(image, index, params, existing_caption=existing_caption) - - if params.flip: - save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption) - - -def split_pic(image, inverse_xy, width, height, overlap_ratio): - if inverse_xy: - from_w, from_h = image.height, image.width - to_w, to_h = height, width - else: - from_w, from_h = image.width, image.height - to_w, to_h = width, height - h = from_h * to_w // from_w - if inverse_xy: - image = image.resize((h, to_w)) - else: - image = image.resize((to_w, h)) - - split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) - y_step = (h - to_h) / (split_count - 1) - for i in range(split_count): - y = int(y_step * i) - if inverse_xy: - splitted = image.crop((y, 0, y + to_h, to_w)) - else: - splitted = image.crop((0, y, to_w, y + to_h)) - yield splitted - -# not using torchvision.transforms.CenterCrop because it doesn't allow float regions -def center_crop(image: Image, w: int, h: int): - iw, ih = image.size - if ih / h < iw / w: - sw = w * ih / h - box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih - else: - sh = h * iw / w - box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 - return image.resize((w, h), Image.Resampling.LANCZOS, box) - - -def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, 
threshold): - iw, ih = image.size - err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h)) - wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64) - if minarea <= w * h <= maxarea and err(w, h) <= threshold), - key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1], - default=None - ) - return wh and center_crop(image, *wh) - - -def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): - width = process_width - height = process_height - src = os.path.abspath(process_src) - dst = os.path.abspath(process_dst) - split_threshold = max(0.0, min(1.0, split_threshold)) - overlap_ratio = max(0.0, min(0.9, overlap_ratio)) - - assert src != dst, 'same directory specified as source and destination' - - os.makedirs(dst, exist_ok=True) - - files = listfiles(src) - - shared.state.job = "preprocess" - shared.state.textinfo = "Preprocessing..." - shared.state.job_count = len(files) - - params = PreprocessParams() - params.dstdir = dst - params.flip = process_flip - params.process_caption = process_caption - params.process_caption_deepbooru = process_caption_deepbooru - params.preprocess_txt_action = preprocess_txt_action - - pbar = tqdm.tqdm(files) - for index, imagefile in enumerate(pbar): - params.subindex = 0 - filename = os.path.join(src, imagefile) - try: - img = Image.open(filename) - img = ImageOps.exif_transpose(img) - img = img.convert("RGB") - except Exception: - continue - - description = f"Preprocessing [Image {index}/{len(files)}]" - pbar.set_description(description) - shared.state.textinfo = description - - params.src = filename - - existing_caption = None - existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt" - if os.path.exists(existing_caption_filename): - with open(existing_caption_filename, 'r', encoding="utf8") as file: - existing_caption = file.read() - - if shared.state.interrupted: - break - - if img.height > img.width: - ratio = (img.width * height) / (img.height * width) - inverse_xy = False - else: - ratio = (img.height * width) / (img.width * height) - inverse_xy = True - - process_default_resize = True - - if process_split and ratio < 1.0 and ratio <= split_threshold: - for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio): - save_pic(splitted, index, params, existing_caption=existing_caption) - process_default_resize = False - - if process_focal_crop and img.height != img.width: - - dnn_model_path = None - try: - dnn_model_path = autocrop.download_and_cache_models() - except Exception as e: - print("Unable to load face detection model for auto crop selection. 
Falling back to lower quality haar method.", e) - - autocrop_settings = autocrop.Settings( - crop_width = width, - crop_height = height, - face_points_weight = process_focal_crop_face_weight, - entropy_points_weight = process_focal_crop_entropy_weight, - corner_points_weight = process_focal_crop_edges_weight, - annotate_image = process_focal_crop_debug, - dnn_model_path = dnn_model_path, - ) - for focal in autocrop.crop_image(img, autocrop_settings): - save_pic(focal, index, params, existing_caption=existing_caption) - process_default_resize = False - - if process_multicrop: - cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) - if cropped is not None: - save_pic(cropped, index, params, existing_caption=existing_caption) - else: - print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)") - process_default_resize = False - - if process_keep_original_size: - save_pic(img, index, params, existing_caption=existing_caption) - process_default_resize = False - - if process_default_resize: - img = images.resize_image(1, img, width, height) - save_pic(img, index, params, existing_caption=existing_caption) - - shared.state.nextjob() diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 35c4feeff45..f149ad1f0f2 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -3,7 +3,6 @@ import gradio as gr import modules.textual_inversion.textual_inversion -import modules.textual_inversion.preprocess from modules import sd_hijack, shared @@ -15,12 +14,6 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old): return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" -def preprocess(*args): - modules.textual_inversion.preprocess.preprocess(*args) - - return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", "" - - def train_embedding(*args): assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' diff --git a/modules/ui.py b/modules/ui.py index 08e0ad775c6..d80486dd4ac 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -912,71 +912,6 @@ def select_img2img_tab(tab): with gr.Column(): create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - with gr.Tab(label="Preprocess images", id="preprocess_images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size") - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', 
elem_id="train_process_focal_crop") - process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Column(visible=False) as process_multicrop_col: - gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') - with gr.Row(): - process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim") - process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim") - with gr.Row(): - process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea") - process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea") - with gr.Row(): - process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") - process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - process_multicrop.change( - fn=lambda show: gr_show(show), - inputs=[process_multicrop], - outputs=[process_multicrop_col], - ) - def get_textual_inversion_template_names(): return sorted(textual_inversion.textual_inversion_templates) @@ -1077,42 +1012,6 @@ def get_textual_inversion_template_names(): ] ) - run_preprocess.click( - 
fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - dummy_component, - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_keep_original_size, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - process_multicrop, - process_multicrop_mindim, - process_multicrop_maxdim, - process_multicrop_minarea, - process_multicrop_maxarea, - process_multicrop_objective, - process_multicrop_threshold, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - train_embedding.click( fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", @@ -1186,12 +1085,6 @@ def get_textual_inversion_template_names(): outputs=[], ) - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) settings = ui_settings.UiSettings() diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 802e1ce71a1..fbad0800aca 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,9 +1,10 @@ import gradio as gr -from modules import scripts, shared, ui_common, postprocessing, call_queue +from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow import modules.generation_parameters_copypaste as parameters_copypaste def create_ui(): + dummy_component = gr.Label(visible=False) tab_index = gr.State(value=0) with gr.Row(equal_height=False, variant='compact'): @@ -20,11 +21,13 @@ def create_ui(): extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - script_inputs = scripts.scripts_postproc.setup_ui() with gr.Column(): + toprow = ui_toprow.Toprow(is_compact=True, is_img2img=False, id_part="extras") + toprow.create_inline_toprow_image() + submit = toprow.submit + result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) @@ -33,7 +36,9 @@ def create_ui(): submit.click( fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + _js="submit_extras", inputs=[ + dummy_component, tab_index, extras_image, image_batch, @@ -45,8 +50,9 @@ def create_ui(): outputs=[ result_images, html_info_x, - html_info, - ] + html_log, + ], + show_progress=False, ) parameters_copypaste.add_paste_fields("extras", extras_image, None) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 985b5a2dd9c..88838f97749 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -34,8 +34,10 @@ class Toprow: submit_box = None - def __init__(self, is_img2img, is_compact=False): - id_part = "img2img" if is_img2img else "txt2img" + def __init__(self, is_img2img, is_compact=False, id_part=None): + if id_part is None: + id_part = "img2img" if is_img2img else "txt2img" + 
self.id_part = id_part self.is_img2img = is_img2img self.is_compact = is_compact diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py new file mode 100644 index 00000000000..243e3ad9c62 --- /dev/null +++ b/scripts/postprocessing_caption.py @@ -0,0 +1,30 @@ +from modules import scripts_postprocessing, ui_components, deepbooru, shared +import gradio as gr + + +class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing): + name = "Caption" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Caption") as enable: + option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False) + + return { + "enable": enable, + "option": option, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): + if not enable: + return + + captions = [pp.caption] + + if "Deepbooru" in option: + captions.append(deepbooru.model.tag(pp.image)) + + if "BLIP" in option: + captions.append(shared.interrogator.generate_caption(pp.image)) + + pp.caption = ", ".join([x for x in captions if x]) diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py index a7d80d40e2b..e1e156ddcaa 100644 --- a/scripts/postprocessing_codeformer.py +++ b/scripts/postprocessing_codeformer.py @@ -1,28 +1,28 @@ from PIL import Image import numpy as np -from modules import scripts_postprocessing, codeformer_model +from modules import scripts_postprocessing, codeformer_model, ui_components import gradio as gr -from modules.ui_components import FormRow - class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): name = "CodeFormer" order = 3000 def ui(self): - with FormRow(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") + with ui_components.InputAccordion(False, label="CodeFormer") as enable: + with gr.Row(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") return { + "enable": enable, "codeformer_visibility": codeformer_visibility, "codeformer_weight": codeformer_weight, } - def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): - if codeformer_visibility == 0: + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight): + if codeformer_visibility == 0 or not enable: return restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) diff --git a/scripts/postprocessing_create_flipped_copies.py b/scripts/postprocessing_create_flipped_copies.py new file mode 100644 index 00000000000..3425571dc3b --- /dev/null +++ b/scripts/postprocessing_create_flipped_copies.py @@ -0,0 +1,32 @@ +from PIL import ImageOps, Image + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing): + name = "Create flipped copies" + order = 4000 + + def ui(self): + 
with ui_components.InputAccordion(False, label="Create flipped copies") as enable: + with gr.Row(): + option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False) + + return { + "enable": enable, + "option": option, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): + if not enable: + return + + if "Horizontal" in option: + pp.extra_images.append(ImageOps.mirror(pp.image)) + + if "Vertical" in option: + pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)) + + if "Both" in option: + pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT)) diff --git a/scripts/postprocessing_focal_crop.py b/scripts/postprocessing_focal_crop.py new file mode 100644 index 00000000000..d3baf29878a --- /dev/null +++ b/scripts/postprocessing_focal_crop.py @@ -0,0 +1,54 @@ + +from modules import scripts_postprocessing, ui_components, errors +import gradio as gr + +from modules.textual_inversion import autocrop + + +class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): + name = "Auto focal point crop" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: + face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight") + entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight") + edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight") + debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + return { + "enable": enable, + "face_weight": face_weight, + "entropy_weight": entropy_weight, + "edges_weight": edges_weight, + "debug": debug, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug): + if not enable: + return + + if not pp.shared.target_width or not pp.shared.target_height: + return + + dnn_model_path = None + try: + dnn_model_path = autocrop.download_and_cache_models() + except Exception: + errors.report("Unable to load face detection model for auto crop selection. 
Falling back to lower quality haar method.", exc_info=True) + + autocrop_settings = autocrop.Settings( + crop_width=pp.shared.target_width, + crop_height=pp.shared.target_height, + face_points_weight=face_weight, + entropy_points_weight=entropy_weight, + corner_points_weight=edges_weight, + annotate_image=debug, + dnn_model_path=dnn_model_path, + ) + + result, *others = autocrop.crop_image(pp.image, autocrop_settings) + + pp.image = result + pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others] + diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py index d854f3f7748..6e7566055d2 100644 --- a/scripts/postprocessing_gfpgan.py +++ b/scripts/postprocessing_gfpgan.py @@ -1,26 +1,25 @@ from PIL import Image import numpy as np -from modules import scripts_postprocessing, gfpgan_model +from modules import scripts_postprocessing, gfpgan_model, ui_components import gradio as gr -from modules.ui_components import FormRow - class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): name = "GFPGAN" order = 2000 def ui(self): - with FormRow(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") + with ui_components.InputAccordion(False, label="GFPGAN") as enable: + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_gfpgan_visibility") return { + "enable": enable, "gfpgan_visibility": gfpgan_visibility, } - def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): - if gfpgan_visibility == 0: + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility): + if gfpgan_visibility == 0 or not enable: return restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) diff --git a/scripts/postprocessing_split_oversized.py b/scripts/postprocessing_split_oversized.py new file mode 100644 index 00000000000..c4a03160fc6 --- /dev/null +++ b/scripts/postprocessing_split_oversized.py @@ -0,0 +1,71 @@ +import math + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +def split_pic(image, inverse_xy, width, height, overlap_ratio): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted + + +class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing): + name = "Split oversized images" + order = 4000 + + def ui(self): + with ui_components.InputAccordion(False, label="Split oversized images") as enable: + with gr.Row(): + split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold") + overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio") + + return { + "enable": enable, + "split_threshold": split_threshold, + "overlap_ratio": 
overlap_ratio, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio): + if not enable: + return + + width = pp.shared.target_width + height = pp.shared.target_height + + if not width or not height: + return + + if pp.image.height > pp.image.width: + ratio = (pp.image.width * height) / (pp.image.height * width) + inverse_xy = False + else: + ratio = (pp.image.height * width) / (pp.image.width * height) + inverse_xy = True + + if ratio >= 1.0 and ratio > split_threshold: + return + + result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio) + + pp.image = result + pp.extra_images = [pp.create_copy(x) for x in others] + diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index eb42a29e521..ed709688de4 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -81,6 +81,14 @@ def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_w return image + def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): + if upscale_mode == 1: + pp.shared.target_width = upscale_to_width + pp.shared.target_height = upscale_to_height + else: + pp.shared.target_width = int(pp.image.width * upscale_by) + pp.shared.target_height = int(pp.image.height * upscale_by) + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): if upscaler_1_name == "None": upscaler_1_name = None @@ -126,6 +134,10 @@ def ui(self): "upscaler_name": upscaler_name, } + def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): + pp.shared.target_width = int(pp.image.width * upscale_by) + pp.shared.target_height = int(pp.image.height * upscale_by) + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): if upscaler_name is None or upscaler_name == "None": return diff --git a/scripts/processing_autosized_crop.py b/scripts/processing_autosized_crop.py new file mode 100644 index 00000000000..c098022645d --- /dev/null +++ b/scripts/processing_autosized_crop.py @@ -0,0 +1,64 @@ +from PIL import Image + +from modules import scripts_postprocessing, ui_components +import gradio as gr + + +def center_crop(image: Image, w: int, h: int): + iw, ih = image.size + if ih / h < iw / w: + sw = w * ih / h + box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih + else: + sh = h * iw / w + box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 + return image.resize((w, h), Image.Resampling.LANCZOS, box) + + +def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): + iw, ih = image.size + err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h)) + wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64) + if minarea <= w * h <= maxarea and err(w, h) <= threshold), + key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1], + default=None + ) + return wh and center_crop(image, *wh) + + +class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing): + name = "Auto-sized crop" + order = 4000 + + def ui(self): + with 
ui_components.InputAccordion(False, label="Auto-sized crop") as enable: + gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') + with gr.Row(): + mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim") + maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim") + with gr.Row(): + minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea") + maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea") + with gr.Row(): + objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective") + threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold") + + return { + "enable": enable, + "mindim": mindim, + "maxdim": maxdim, + "minarea": minarea, + "maxarea": maxarea, + "objective": objective, + "threshold": threshold, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold): + if not enable: + return + + cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold) + if cropped is not None: + pp.image = cropped + else: + print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)") From a5f61aa8c5933d8e5a0e0aa841138eeaccd86d62 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Dec 2023 18:03:34 +0300 Subject: [PATCH 0356/1174] potential fix for #14172 --- modules/sd_hijack.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3d340fc9b21..14fe62c73cb 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -38,6 +38,10 @@ optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None +ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) +sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) + + def list_optimizers(): new_optimizers = script_callbacks.list_optimizers_callback() @@ -255,9 +259,6 @@ def flatten(el): import modules.models.diffusion.ddpm_edit - ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) - sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) - if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion): sd_unet.original_forward = ldm_original_forward elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion): @@ -303,11 +304,6 @@ def undo_hijack(self, m): self.layers = None self.clip = None - patches.undo(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward") - patches.undo(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward") - - sd_unet.original_forward = None - def apply_circular(self, enable): if self.circular_enabled == enable: From ac02216e540cd581f9169c6c791e55721e3117b0 Mon Sep 17 00:00:00 
2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Dec 2023 19:35:47 +0300 Subject: [PATCH 0357/1174] alternate implementation for unet forward replacement that does not depend on hijack being applied --- modules/sd_hijack.py | 7 +++++-- modules/sd_unet.py | 14 ++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 14fe62c73cb..e139d9964cb 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -38,8 +38,11 @@ optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None -ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) -sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward) +ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) +ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) + +sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) +sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) def list_optimizers(): diff --git a/modules/sd_unet.py b/modules/sd_unet.py index 6a7bc9e2646..a771849c8c2 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -5,8 +5,7 @@ unet_options = [] current_unet_option = None current_unet = None -original_forward = None - +original_forward = None # not used, only left temporarily for compatibility def list_unets(): new_unets = script_callbacks.list_unets_callback() @@ -84,9 +83,12 @@ def deactivate(self): pass -def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): - if current_unet is not None: - return current_unet.forward(x, timesteps, context, *args, **kwargs) +def create_unet_forward(original_forward): + def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): + if current_unet is not None: + return current_unet.forward(x, timesteps, context, *args, **kwargs) + + return original_forward(self, x, timesteps, context, *args, **kwargs) - return original_forward(self, x, timesteps, context, *args, **kwargs) + return UNetModel_forward From 309a606c2fa645b6b8623f96ea56117e685a47fb Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 2 Dec 2023 13:07:45 -0500 Subject: [PATCH 0358/1174] ensure that original alpha bar always exists --- modules/processing.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index d73c8bfc1c9..bfa59038b10 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,15 +882,17 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = 
rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod + + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From 81c4ddf6ebebe6f18338de3b0391da1d8521a525 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 2 Dec 2023 13:11:00 -0500 Subject: [PATCH 0359/1174] fix linting --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index bfa59038b10..eeccea743a7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod - + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) if opts.use_downcasted_alpha_bar: From 83e8c322762c545fd589c060811379582926060f Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 2 Dec 2023 13:30:53 -0500 Subject: [PATCH 0360/1174] Fix `save_samples` being checked early when saving masked composite --- modules/processing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 5ab6dddef39..4f265801c5e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -938,14 +938,14 @@ def infotext(index=0, use_main_prompt=False): if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): + if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): image_mask = p.mask_for_overlay.convert('RGB') image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') - if opts.save_mask: + if save_samples and opts.save_mask: images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") - if opts.save_mask_composite: + if save_samples and opts.save_mask_composite: images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") if opts.return_mask: From 4a43334376d9e116f7a1446f042f9af9c0484fc6 Mon 
Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:05:42 -0500 Subject: [PATCH 0361/1174] Revert 309a606c --- modules/processing.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index eeccea743a7..d73c8bfc1c9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,17 +882,15 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod - - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From dc1adeecdd02f3fb910481e808a6d60a77100fea Mon Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:06:56 -0500 Subject: [PATCH 0362/1174] Create alphas_cumprod_original on full precision path --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index de80a493a6f..976c7d5be7b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -374,6 +374,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.no_half: model.float() + model.alphas_cumprod_original = alphas_cumprod devices.dtype_unet = torch.float32 timer.record("apply float()") else: From 78acdcf677a96894651ff0d7d8287f2a994f3781 Mon Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:09:18 -0500 Subject: [PATCH 0363/1174] fix variable --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 976c7d5be7b..5a19a00a50a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -374,7 +374,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.no_half: model.float() - model.alphas_cumprod_original = alphas_cumprod + model.alphas_cumprod_original = model.alphas_cumprod devices.dtype_unet = torch.float32 timer.record("apply float()") else: From 9528d66c9479d02c83b8db6107f6b0cb741612dc Mon Sep 17 00:00:00 2001 From: catboxanon 
<122327233+catboxanon@users.noreply.github.com>
Date: Sat, 2 Dec 2023 14:56:26 -0500
Subject: [PATCH 0364/1174] Re-add setting lost as part of e294e46

---
 modules/shared_options.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/shared_options.py b/modules/shared_options.py
index 859dee40456..e5de0d018fb 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -255,6 +255,7 @@
     "keyedit_precision_attention": OptionInfo(0.1, "Precision for (attention:1.1) when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
+    "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
     "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
 }))

From 609dea36ea919aa7db42fd4233c416a45c74578b Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Sat, 2 Dec 2023 18:56:49 -0700
Subject: [PATCH 0365/1174] Added utility functions related to processing masks.

---
 modules/images.py | 191 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 191 insertions(+)

diff --git a/modules/images.py b/modules/images.py
index eb644733898..b5a0cead67a 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -776,3 +776,194 @@ def flatten(img, bgcolor):
         img = background
 
     return img.convert('RGB')
+
+
+def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
+    """
+    Generalized convolution filter capable of applying
+    weighted mean, median, maximum, and minimum filters
+    parametrically using an arbitrary kernel.
+
+    Args:
+        img (nparray):
+            The image, a 2-D array of floats, to which the filter is being applied.
+        kernel (nparray):
+            The kernel, a 2-D array of floats.
+        kernel_center (nparray):
+            The kernel center coordinate, a 1-D array with two elements.
+        percentile_min (float):
+            The lower bound of the histogram window used by the filter,
+            from 0 to 1.
+        percentile_max (float):
+            The upper bound of the histogram window used by the filter,
+            from 0 to 1.
+        min_width (float):
+            The minimum size of the histogram window bounds, in weight units.
+            Must be greater than 0.
+
+    Returns:
+        (nparray): A filtered copy of the input image "img", a 2-D array of floats.
+    """
+
+    # Converts an index tuple into a vector.
+    def vec(x):
+        return np.array(x)
+
+    kernel_min = -kernel_center
+    kernel_max = vec(kernel.shape) - kernel_center
+
+    def weighted_histogram_filter_single(idx):
+        idx = vec(idx)
+        min_index = np.maximum(0, idx + kernel_min)
+        max_index = np.minimum(vec(img.shape), idx + kernel_max)
+        window_shape = max_index - min_index
+
+        class WeightedElement:
+            """
+            An element of the histogram, its weight
+            and bounds.
+            """
+            def __init__(self, value, weight):
+                self.value: float = value
+                self.weight: float = weight
+                self.window_min: float = 0.0
+                self.window_max: float = 1.0
+
+        # Collect the values in the image as WeightedElements,
+        # weighted by their corresponding kernel values.
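+        # NOTE: every pixel gathers one sample per kernel cell and then
+        # sorts them, so per-pixel cost grows with kernel area; keeping
+        # kernels small (e.g. 5x5) keeps the filter tractable.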
+ values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. + for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. 
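+    Concretely: the gaussian is sampled at zero_radius, that sample is
+    subtracted from every kernel entry, the result is rescaled so the
+    peak at the mean is 1 again, and negative tails are clamped to zero.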
+ """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x /= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + From 73ab982d1b7394574d1cf2e0a151bc457eeed769 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sat, 2 Dec 2023 21:07:02 -0700 Subject: [PATCH 0366/1174] Blend masks are now produced afterward, based on an estimate of the visual difference between the original and modified latent images. This should remove ghosting and clipping artifacts from masks, while preserving the details of largely unchanged content. --- modules/processing.py | 119 ++++++++++++++++++++++++++++++++---------- 1 file changed, 90 insertions(+), 29 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 92fdebadd98..ad716e11f7c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -9,7 +9,7 @@ import torch import numpy as np -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageFilter import random import cv2 from skimage import exposure @@ -62,6 +62,16 @@ def apply_color_correction(correction, original_image): return image.convert('RGB') +def uncrop(image, dest_size, paste_loc): + x, y, w, h = paste_loc + base_image = Image.new('RGBA', dest_size) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + return image + + def apply_overlay(image, paste_loc, index, overlays): if overlays is None or index >= len(overlays): return image @@ -69,11 +79,7 @@ def apply_overlay(image, paste_loc, index, overlays): overlay = overlays[index] if paste_loc is not None: - x, y, w, h = paste_loc - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image + image = uncrop(image, (overlay.width, overlay.height), paste_loc) image = image.convert('RGBA') image.alpha_composite(overlay) @@ -140,6 +146,7 @@ class StableDiffusionProcessing: do_not_save_grid: bool = False extra_generation_params: dict[str, Any] = None overlay_images: list = None + masks_for_overlay: list = None eta: float = None do_not_reload_embeddings: bool = False denoising_strength: float = 0 @@ -865,11 +872,66 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim + # todo: generate masks the old fashioned way else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + # Generate the mask(s) based on similarity between the original and denoised latent vectors + if getattr(p, "image_mask", None) is not None: + # latent_mask = p.nmask[0].float().cpu() + + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_orig = p.init_latent + latent_proc = 
samples_ddim + latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, p.width, p.height) + converted_mask = create_binary_mask(converted_mask) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if p.paste_to is not None: + converted_mask = uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + p.paste_to) + + p.masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + p.overlay_images[i] = image_masked.convert('RGBA') + + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, + target_device=devices.cpu, + check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -892,7 +954,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = batch_params.images def infotext(index=0, use_main_prompt=False): - return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) + return create_infotext(p, p.prompts, p.seeds, p.subseeds, + use_main_prompt=use_main_prompt, index=index, + all_negative_prompts=p.negative_prompts) save_samples = p.save_samples() @@ -923,19 +987,27 @@ def infotext(index=0, use_main_prompt=False): images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) + # If the intention is to show the output from the model + # that is being composited over the original image, + # we need to keep the original image around + # and use it in the composite step. 
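+                    # (The copy is taken before apply_overlay() pastes the
+                    # masked original back over the generated sample.)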
+ original_denoised_image = image.copy() image = apply_overlay(image, p.paste_to, i, p.overlay_images) if save_samples: - images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) + images.save_image(image, p.outpath_samples, "", p.seeds[i], + p.prompts[i], opts.samples_format, info=infotext(i), p=p) text = infotext(i) infotexts.append(text) if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): - image_mask = p.mask_for_overlay.convert('RGB') - image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): + image_mask = p.masks_for_overlay[i].convert('RGB') + image_mask_composite = Image.composite( + original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), + images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') if opts.save_mask: images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") @@ -1364,7 +1436,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): nmask: torch.Tensor = field(default=None, init=False) image_conditioning: torch.Tensor = field(default=None, init=False) init_img_hash: str = field(default=None, init=False) - mask_for_overlay: Image = field(default=None, init=False) init_latent: torch.Tensor = field(default=None, init=False) def __post_init__(self): @@ -1415,12 +1486,6 @@ def init(self, all_prompts, all_seeds, all_subseeds): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - np_mask = np.array(image_mask).astype(np.float32) - np_mask /= 255 - np_mask = 1-pow(1-np_mask, 100) - np_mask *= 255 - np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1431,13 +1496,8 @@ def init(self, all_prompts, all_seeds, all_subseeds): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - np_mask = np.array(image_mask).astype(np.float32) - np_mask /= 255 - np_mask = 1-pow(1-np_mask, 100) - np_mask *= 255 - np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) + self.masks_for_overlay = [] self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1459,10 +1519,8 @@ def init(self, all_prompts, all_seeds, all_subseeds): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) - - self.overlay_images.append(image_masked.convert('RGBA')) + 
self.overlay_images.append(image) + self.masks_for_overlay.append(image_mask) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1486,6 +1544,9 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.overlay_images is not None: self.overlay_images = self.overlay_images * self.batch_size + if self.masks_for_overlay is not None: + self.masks_for_overlay = self.masks_for_overlay * self.batch_size + if self.color_corrections is not None and len(self.color_corrections) == 1: self.color_corrections = self.color_corrections * self.batch_size From bb04d400c95df01d191ef6c1a43e66b95425fa33 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sat, 2 Dec 2023 21:08:26 -0700 Subject: [PATCH 0367/1174] Rewrote latent_blend() to use in-place operations and to aggressively "del" references with the intention of minimizing allocations and easing garbage collection. --- modules/sd_samplers_cfg_denoiser.py | 41 ++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index ceb612d7936..efbe7a403de 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -102,29 +102,44 @@ def latent_blend(a, b, t): The "detail_preservation" factor biases the magnitude interpolation towards the larger of the two magnitudes. """ - # Record the original latent vector magnitudes. - # We bring them to a power so that larger magnitudes are favored over smaller ones. - # 64-bit operations are used here to allow large exponents. - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation + # NOTE: We use inplace operations wherever possible. one_minus_t = 1 - t - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation) - # Linearly interpolate the image vectors. - image_interp = a * one_minus_t + b * t + a_scaled = a * one_minus_t + b_scaled = b * t + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) # 64-bit operations are used here to allow large exponents. - image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001 + current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation) + del a_magnitude, b_magnitude, one_minus_t # Change the linearly interpolated image vectors' magnitudes to the value we want. # This is the last 64-bit operation. 
- image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype) - - return image_interp + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + + image_interp_scaled = image_interp_scaled.to(result_type) + del result_type + + return image_interp_scaled def get_modified_nmask(nmask, _sigma): """ From d3fdc4af61b7560eede52290e1ede48185680089 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 3 Dec 2023 18:22:00 +0900 Subject: [PATCH 0368/1174] rework mask and mask_composite logic --- modules/processing.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 4f265801c5e..6f01c95f5b6 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -938,21 +938,20 @@ def infotext(index=0, use_main_prompt=False): if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): - image_mask = p.mask_for_overlay.convert('RGB') - image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') - - if save_samples and opts.save_mask: - images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") - - if save_samples and opts.save_mask_composite: - images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") - - if opts.return_mask: - output_images.append(image_mask) - - if opts.return_mask_composite: - output_images.append(image_mask_composite) + if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: + if opts.return_mask or opts.save_mask: + image_mask = p.mask_for_overlay.convert('RGB') + if save_samples and opts.save_mask: + images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") + if opts.return_mask: + output_images.append(image_mask) + + if opts.return_mask_composite or opts.save_mask_composite: + image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + if save_samples and opts.save_mask_composite: + images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") + if opts.return_mask_composite: + output_images.append(image_mask_composite) del x_samples_ddim From d92ce145bba714c5b257b9853aa22681233651b8 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 3 Dec 2023 16:50:20 +0200 Subject: [PATCH 0369/1174] Add import_hook hack to work around basicsr incompatibility Fixes #13985 --- modules/import_hook.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/modules/import_hook.py b/modules/import_hook.py index 28c67dfa897..eba9a372929 100644 --- a/modules/import_hook.py 
+++ b/modules/import_hook.py @@ -3,3 +3,14 @@ # this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it if "--xformers" not in "".join(sys.argv): sys.modules["xformers"] = None + +# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks +# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985 +try: + import torchvision.transforms.functional_tensor # noqa: F401 +except ImportError: + try: + import torchvision.transforms.functional as functional + sys.modules["torchvision.transforms.functional_tensor"] = functional + except ImportError: + pass # shrug... From 28a2b5b4aab43424733039c31d910e8b8dd507cd Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sun, 3 Dec 2023 14:20:20 -0700 Subject: [PATCH 0370/1174] Fixed a math mistake. --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index 6648097e851..9495349866b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -969,7 +969,7 @@ def gaussian_kernel_func(coordinate): x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 x = gaussian(x) x -= gauss_zero - x /= gauss_kernel_scale + x *= gauss_kernel_scale x = max(0.0, x) return x From 552f8bc832cd21ee0338e08b6a701687d0d79fad Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sun, 3 Dec 2023 14:49:41 -0700 Subject: [PATCH 0371/1174] "Uncrop" the original denoised image for the composite step, fixing a "ValueError: Images do not match" *shudder* --- modules/processing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index 66aaab83150..cd7216f8313 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -994,6 +994,10 @@ def infotext(index=0, use_main_prompt=False): # we need to keep the original image around # and use it in the composite step. original_denoised_image = image.copy() + + if p.paste_to is not None: + original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to) + image = apply_overlay(image, p.paste_to, i, p.overlay_images) if save_samples: From 639ccf254bd4d072f33333abb1ada3d08aaab470 Mon Sep 17 00:00:00 2001 From: illtellyoulater <3078931+illtellyoulater@users.noreply.github.com> Date: Mon, 4 Dec 2023 02:35:35 +0000 Subject: [PATCH 0372/1174] Update launch_utils.py to fix wrong dep. checks and reinstalls Fixes failing dependency checks for extensions having a different package name and import name (for example ffmpeg-python / ffmpeg), which currently is causing the unneeded reinstall of packages at runtime. In fact with current code, the same string is used when installing a package and when checking for its presence, as you can see in the following example: > launch_utils.run_pip("install ffmpeg-python", "required package") [ Installing required package: "ffmpeg-python" ... ] [ Installed ] > launch_utils.is_installed("ffmpeg-python") False ... 
which would actually return true with: > launch_utils.is_installed("ffmpeg") True --- modules/launch_utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 6e54d06367c..6664c5e0415 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -6,6 +6,7 @@ import shutil import sys import importlib.util +import importlib.metadata import platform import json from functools import lru_cache @@ -119,11 +120,16 @@ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_ def is_installed(package): try: - spec = importlib.util.find_spec(package) - except ModuleNotFoundError: - return False + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False - return spec is not None + return spec is not None + + return dist is not None def repo_dir(name): From 06725af40b94a146c56e693a47cbec6d0af55396 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sun, 3 Dec 2023 21:26:12 -0700 Subject: [PATCH 0373/1174] Lint --- modules/launch_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 6664c5e0415..e71edd01d2b 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -120,12 +120,12 @@ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_ def is_installed(package): try: - dist = importlib.metadata.distribution(package) + dist = importlib.metadata.distribution(package) except importlib.metadata.PackageNotFoundError: - try: + try: spec = importlib.util.find_spec(package) except ModuleNotFoundError: - return False + return False return spec is not None From 9e1f3feb12a7cfe4fd426dd3df5431c805746ecc Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 09:15:19 +0300 Subject: [PATCH 0374/1174] make webui not crash when running with --disable-all-extensions option --- modules/models/diffusion/ddpm_edit.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index b892d5fc7b0..6db340da40b 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -24,10 +24,15 @@ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config from ldm.modules.ema import LitEma from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like from ldm.models.diffusion.ddim import DDIMSampler +try: + from ldm.models.autoencoder import VQModelInterface +except Exception: + class VQModelInterface: + pass __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', From 48fae7ccdc2fe2d2ba8e8cfcb17b56028734e570 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 09:35:52 +0300 Subject: [PATCH 0375/1174] update changelog --- CHANGELOG.md | 162 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c72359fc42..67429bbff0f 100644 --- 
a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,165 @@
+## 1.7.0
+
+### Features:
+* settings tab rework: add search field, add categories, split UI settings page into many
+* add altdiffusion-m18 support ([#13364](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13364))
+* support inference with LyCORIS GLora networks ([#13610](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13610))
+* add lora-embedding bundle system ([#13568](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13568))
+* option to move prompt from top row into generation parameters
+* add support for SSD-1B ([#13865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13865))
+* support inference with OFT networks ([#13692](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13692))
+* script metadata and DAG sorting mechanism ([#13944](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13944))
+* support HyperTile optimization ([#13948](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13948))
+* add support for SD 2.1 Turbo ([#14170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14170))
+* remove Train->Preprocessing tab and put all its functionality into Extras tab
+* initial IPEX support for Intel Arc GPU ([#14171](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14171))
+
+### Minor:
+* allow reading model hash from images in img2img batch mode ([#12767](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12767))
+* add option to align with sgm repo's sampling implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818))
+* extra field for lora metadata viewer: `ss_output_name` ([#12838](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12838))
+* add action in settings page to calculate all SD checkpoint hashes ([#12909](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12909))
+* add button to copy prompt to style editor ([#12975](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12975))
+* add --skip-load-model-at-start option ([#13253](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13253))
+* write infotext to gif images
+* read infotext from gif images ([#13068](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13068))
+* allow configuring the initial state of InputAccordion in ui-config.json ([#13189](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13189))
+* allow editing whitespace delimiters for ctrl+up/ctrl+down prompt editing ([#13444](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13444))
+* prevent accidentally closing popup dialogs ([#13480](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13480))
+* added option to play notification sound or not ([#13631](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13631))
+* show the preview image in the full screen image viewer if available ([#13459](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13459))
+* support for webui.settings.bat ([#13638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13638))
+* add an option to not print stack traces on ctrl+c
+* start/restart generation by Ctrl (Alt) + Enter ([#13644](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13644))
+* update prompts_from_file script to allow concatenating entries with the general prompt ([#13733](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13733))
+* added a visible checkbox to input accordion
+* added an option to hide all txt2img/img2img parameters in an accordion ([#13826](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13826))
+* added 'Path' sorting option for Extra network cards ([#13968](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13968))
+* enable prompt hotkeys in style editor ([#13931](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13931))
+* option to show batch img2img results in UI ([#14009](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14009))
+* infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page
+* add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046))
+* support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126))
+* allow use of multiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125))
+
+### Extensions and API:
+* update gradio to 3.41.2
+* support installed extensions list api ([#12774](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12774))
+* update pnginfo API to return dict with parsed values
+* add noisy latent to `ExtraNoiseParams` for callback ([#12856](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12856))
+* show extension datetime in UTC ([#12864](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12864), [#12865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12865), [#13281](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13281))
+* add an option to choose how to combine hires fix and refiner
+* include program version in info response. ([#13135](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13135))
+* sd_unet support for SDXL
+* patch DDPM.register_betas so that users can put given_betas in model yaml ([#13276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13276))
+* xyz_grid: add prepare ([#13266](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13266))
+* allow multiple localization files with same language in extensions ([#13077](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13077))
+* add onEdit function for js and rework token-counter.js to use it
+* fix the key error exception when processing override_settings keys ([#13567](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13567))
+* ability for extensions to return custom data via api in response.images ([#13463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13463))
+* call state.jobnext() before postproces*() ([#13762](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13762))
+* add option to set notification sound volume ([#13884](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13884))
+* update Ruff to 0.1.6 ([#14059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14059))
+* add Block component creation callback ([#14119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14119))
+* catch uncaught exception with ui creation scripts ([#14120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14120))
+* use extension name for determining an extension is installed in the index ([#14063](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14063))
+* update is_installed() from launch_utils.py to fix reinstalling already installed packages ([#14192](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14192))
+
+### Bug Fixes:
+* fix pix2pix producing bad results
+* fix defaults settings page breaking when any of main UI tabs are hidden
+* fix error that causes some extra networks to be disabled if both and are present in the prompt
+* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working
+* prevent duplicate resize handler ([#12795](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12795))
+* small typo: vae resolve bug ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12797))
+* hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12792))
+* don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12780))
+* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs
+* hide --gradio-auth and --api-auth values from /internal/sysinfo report
+* add missing infotext for RNG in options ([#12819](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12819))
+* fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834))
+* honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832))
+* don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12833), [#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855))
+* get progressbar to display correctly in extensions tab
+* keep order in list of checkpoints when loading model that doesn't have a checksum
+* fix inpainting models in txt2img creating black pictures
+* fix generation params regex ([#12876](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12876))
+* fix batch img2img output dir with script ([#12926](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12926))
+* fix #13080 - Hypernetwork/TI preview generation ([#13084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13084))
+* fix bug with sigma min/max overrides. ([#12995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12995))
+* more accurate check for enabling cuDNN benchmark on 16XX cards ([#12924](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12924))
+* don't use multicond parser for negative prompt counter ([#13118](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13118))
+* fix data-sort-name containing spaces ([#13412](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13412))
+* update card on correct tab when editing metadata ([#13411](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13411))
+* fix viewing/editing metadata when filename contains an apostrophe ([#13395](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13395))
+* fix: --sd_model in "Prompts from file or textbox" script is not working ([#13302](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13302))
+* better Support for Portable Git ([#13231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13231))
+* fix issues when webui_dir is not work_dir ([#13210](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13210))
+* fix: lora-bias-backup don't reset cache ([#13178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13178))
+* account for customizable extra network separators when removing extra network text from the prompt ([#12877](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12877))
+* re fix batch img2img output dir with script ([#13170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13170))
+* fix `--ckpt-dir` path separator and option use `short name` for checkpoint dropdown ([#13139](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13139))
+* consolidated allowed preview formats, Fix extra network `.gif` not working as preview ([#13121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13121))
+* fix venv_dir=- environment variable not working as expected on linux ([#13469](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13469))
+* repair unload sd checkpoint button
+* edit-attention fixes ([#13533](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13533))
+* fix bug when using --gfpgan-models-path ([#13718](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13718))
+* properly apply sort order for extra network cards when selected from dropdown
+* fixes generation restart not working for some users when 'Ctrl+Enter' is pressed ([#13962](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13962))
+* thread safe extra network list_items ([#13014](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13014))
+* fix not able to exit metadata popup when pop up is too big ([#14156](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14156))
+* fix auto focal point crop for opencv >= 4.8 ([#14121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14121))
+* make 'use-cpu all' actually apply to 'all' ([#14131](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14131))
+* extras tab batch: actually use original filename
+* make webui not crash when running with --disable-all-extensions option
+
+### Other:
+* non-local condition ([#12814](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12814))
+* fix minor typos ([#12827](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12827))
+* remove xformers Python version check ([#12842](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12842))
+* style: file-metadata word-break ([#12837](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12837))
+* revert SGM noise multiplier change for img2img because it breaks hires fix
+* do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854))
+* [RC 1.6.0 - zoom is partly hidden] Update style.css ([#12839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12839))
+* chore: change extension time format ([#12851](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12851))
+* WEBUI.SH - Use torch 2.1.0 release candidate for Navi 3 ([#12929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12929))
+* add Fallback at images.read_info_from_image if exif data was invalid ([#13028](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13028))
+* update cmd arg description ([#12986](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12986))
+* fix: update shared.opts.data when add_option ([#12957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12957), [#13213](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13213))
+* restore missing tooltips ([#12976](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12976))
+* use default dropdown padding on mobile ([#12880](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12880))
+* put enable console prompts option into settings from commandline args ([#13119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13119))
+* fix some deprecated types ([#12846](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12846))
+* bump to torchsde==0.2.6 ([#13418](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13418))
+* update dragdrop.js ([#13372](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13372))
+* use orderdict as lru cache:opt/bug ([#13313](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13313))
+* XYZ if not include sub grids do not save sub grid ([#13282](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13282))
+* initialize state.time_start before state.job_count ([#13229](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13229))
+* fix fieldname regex ([#13458](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13458))
+* change denoising_strength default to None. ([#13466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13466))
+* fix regression ([#13475](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13475))
+* fix IndexError ([#13630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13630))
+* fix: checkpoints_loaded:{checkpoint:state_dict}, model.load_state_dict issue in dict value empty ([#13535](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13535))
+* update bug_report.yml ([#12991](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12991))
+* requirements_versions httpx==0.24.1 ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839))
+* fix parenthesis auto selection ([#13829](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13829))
+* fix #13796 ([#13797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13797))
+* corrected a typo in `modules/cmd_args.py` ([#13855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13855))
+* feat: fix randn found element of type float at pos 2 ([#14004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14004))
+* adds tqdm handler to logging_config.py for progress bar integration ([#13996](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13996))
+* hotfix: call shared.state.end() after postprocessing done ([#13977](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13977))
+* fix dependency address patch 1 ([#13929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13929))
+* save sysinfo as .json ([#14035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14035))
+* move exception_records related methods to errors.py ([#14084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14084))
+* compatibility ([#13936](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13936))
+* json.dump(ensure_ascii=False) ([#14108](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14108))
+* dir buttons start with / so only the correct dir will be shown and no… ([#13957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13957))
+* alternate implementation for unet forward replacement that does not depend on hijack being applied
+* re-add `keyedit_delimiters_whitespace` setting lost as part of commit e294e46 ([#14178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14178))
+* fix `save_samples` being checked early when saving masked composite ([#14177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14177))
+* slight optimization for mask and mask_composite ([#14181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14181))
+* add import_hook hack to work around basicsr/torchvision incompatibility ([#14186](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14186))
+
 ## 1.6.1
 
 ### Bug Fixes:

From aaacf4823241450d88315af9d465d6815119fe0d Mon Sep 17 00:00:00 2001
From: CodeHatchling
Date: Mon, 4 Dec 2023 01:27:22 -0700
Subject: [PATCH 0376/1174] Organized the settings and UI of soft inpainting to
 allow for toggling the feature, and centralized default values to reduce the
 amount of copy-pasta.
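
For reference, a minimal sketch of the call pattern this patch introduces
(mirroring the img2img.py and soft_inpainting.py changes below; the keyword
names are the constructor's positional parameters):

    import modules.soft_inpainting as si

    # None signals "soft inpainting disabled" to processing and the samplers.
    soft_inpainting = si.SoftInpaintingSettings(
        mask_blend_power=1,             # schedule bias
        mask_blend_scale=0.5,           # preservation strength
        inpaint_detail_preservation=4,  # transition contrast boost
    ) if mask_blend_enabled else None

    if soft_inpainting is not None:
        soft_inpainting.add_generation_params(p.extra_generation_params)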
--- modules/img2img.py | 14 +-- modules/processing.py | 5 +- modules/sd_samplers_cfg_denoiser.py | 35 +++++--- modules/sd_samplers_common.py | 4 +- modules/soft_inpainting.py | 133 ++++++++++++++++++++++++++++ modules/ui.py | 17 ++-- scripts/outpainting_mk_2.py | 15 ++-- scripts/poor_mans_outpainting.py | 15 ++-- test/test_img2img.py | 8 +- 9 files changed, 197 insertions(+), 49 deletions(-) create mode 100644 modules/soft_inpainting.py diff --git a/modules/img2img.py b/modules/img2img.py index 596f741c1ef..3aa8a9cef6c 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -15,6 +15,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html import modules.scripts +import modules.soft_inpainting as si def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): @@ -162,6 +163,7 @@ def img2img(id_task: str, sampler_name: str, mask_blur: int, mask_alpha: float, + mask_blend_enabled: bool, mask_blend_power: float, mask_blend_scale: float, inpaint_detail_preservation: float, @@ -227,6 +229,9 @@ def img2img(id_task: str, assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' + soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ + if mask_blend_enabled else None + p = StableDiffusionProcessingImg2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples, @@ -244,9 +249,7 @@ def img2img(id_task: str, init_images=[image], mask=mask, mask_blur=mask_blur, - mask_blend_power=mask_blend_power, - mask_blend_scale=mask_blend_scale, - inpaint_detail_preservation=inpaint_detail_preservation, + soft_inpainting=soft_inpainting, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -267,9 +270,8 @@ def img2img(id_task: str, if mask: p.extra_generation_params["Mask blur"] = mask_blur - p.extra_generation_params["Mask blending bias"] = mask_blend_power - p.extra_generation_params["Mask blending preservation"] = mask_blend_scale - p.extra_generation_params["Mask blending contrast boost"] = inpaint_detail_preservation + if soft_inpainting is not None: + soft_inpainting.add_generation_params(p.extra_generation_params) with closing(p): if is_batch: diff --git a/modules/processing.py b/modules/processing.py index cd7216f8313..b209c84a31d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -30,6 +30,7 @@ import modules.sd_vae as sd_vae from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion +import modules.soft_inpainting as si from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType @@ -1425,9 +1426,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None - mask_blend_power: float = 1 - mask_blend_scale: float = 0.5 - inpaint_detail_preservation: float = 4 + soft_inpainting: si.SoftInpaintingParameters = si.default inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index efbe7a403de..0ee0b7dde03 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -6,6 +6,7 @@ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import 
CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback +import modules.soft_inpainting as si def catenate_conds(conds): @@ -43,9 +44,7 @@ def __init__(self, sampler): self.model_wrap = None self.mask = None self.nmask = None - self.mask_blend_power = 1 - self.mask_blend_scale = 0.5 - self.inpaint_detail_preservation = 4 + self.soft_inpainting: si.SoftInpaintingParameters = None self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -95,7 +94,8 @@ def update_inner_model(self): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - def latent_blend(a, b, t): + def latent_blend(a, b, t, one_minus_t=None): + """ Interpolates two latent image representations according to the parameter t, where the interpolated vectors' magnitudes are also interpolated separately. @@ -104,7 +104,11 @@ def latent_blend(a, b, t): """ # NOTE: We use inplace operations wherever possible. - one_minus_t = 1 - t + if one_minus_t is None: + one_minus_t = 1 - t + + if self.soft_inpainting is None: + return a * one_minus_t + b * t # Linearly interpolate the image vectors. a_scaled = a * one_minus_t @@ -119,10 +123,10 @@ def latent_blend(a, b, t): current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * one_minus_t + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * t desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation) + desired_magnitude.add_(b_magnitude).pow_(1 / self.soft_inpainting.inpaint_detail_preservation) del a_magnitude, b_magnitude, one_minus_t # Change the linearly interpolated image vectors' magnitudes to the value we want. 
@@ -156,7 +160,10 @@ def get_modified_nmask(nmask, _sigma): NOTE: "mask" is not used """ - return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale) + if self.soft_inpainting is None: + return nmask + + return torch.pow(nmask, (_sigma ** self.soft_inpainting.mask_blend_power) * self.soft_inpainting.mask_blend_scale) if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -176,7 +183,10 @@ def get_modified_nmask(nmask, _sigma): # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) + if self.soft_inpainting is None: + x = latent_blend(self.init_latent, x, self.nmask, self.mask) + else: + x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -279,7 +289,10 @@ def get_modified_nmask(nmask, _sigma): # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) + if self.soft_inpainting is None: + denoised = latent_blend(self.init_latent, denoised, self.nmask, self.mask) + else: + denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index ecd8ab0a093..9682bee3d8e 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -277,9 +277,7 @@ def initialize(self, p) -> dict: self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None - self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None - self.model_wrap_cfg.inpaint_detail_preservation = p.inpaint_detail_preservation if hasattr(p, 'inpaint_detail_preservation') else None + self.model_wrap_cfg.soft_inpainting = p.soft_inpainting if hasattr(p, 'soft_inpainting') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py new file mode 100644 index 00000000000..259c36ec892 --- /dev/null +++ b/modules/soft_inpainting.py @@ -0,0 +1,133 @@ +class SoftInpaintingSettings: + def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + + def get_paste_fields(self): + return [ + (self.mask_blend_power, gen_param_labels.mask_blend_power), + (self.mask_blend_scale, gen_param_labels.mask_blend_scale), + (self.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation), + ] + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + 
dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + + +enabled_ui_label = "Soft inpainting" +enabled_gen_param_label = "Soft inpainting enabled" +enabled_el_id = "soft_inpainting_enabled" + +default = SoftInpaintingSettings(1, 0.5, 4) +ui_labels = SoftInpaintingSettings("Schedule bias", "Preservation strength", "Transition contrast boost") + +ui_info = SoftInpaintingSettings( + mask_blend_power="Shifts when preservation of original content occurs during denoising.", + # "Below 1: Stronger preservation near the end (with low sigma)\n" + # "1: Balanced (proportional to sigma)\n" + # "Above 1: Stronger preservation in the beginning (with high sigma)", + mask_blend_scale="How strongly partially masked content should be preserved.", + # "Low values: Favors generated content.\n" + # "High values: Favors original content.", + inpaint_detail_preservation="Amplifies the contrast that may be lost in partially masked regions.") + +gen_param_labels = SoftInpaintingSettings("Soft inpainting schedule bias", "Soft inpainting preservation strength", "Soft inpainting transition contrast boost") +el_ids = SoftInpaintingSettings("mask_blend_power", "mask_blend_scale", "inpaint_detail_preservation") + + +def gradio_ui(): + import gradio as gr + from modules.ui_components import InputAccordion + """ + with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner: + with gr.Row(): + refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") + create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) + + refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") + + """ + with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: + with gr.Group(): + gr.Markdown( + """ + Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. + **High _Mask blur_** values are recommended! + """) + + result = SoftInpaintingSettings( + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power), + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale), + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation)) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. 
+ + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + return ( + [ + soft_inpainting_enabled, + result.mask_blend_power, + result.mask_blend_scale, + result.inpaint_detail_preservation + ], + [ + (soft_inpainting_enabled, enabled_gen_param_label), + (result.mask_blend_power, gen_param_labels.mask_blend_power), + (result.mask_blend_scale, gen_param_labels.mask_blend_scale), + (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation) + ] + ) diff --git a/modules/ui.py b/modules/ui.py index b13ed66cbaa..0e4fb17aa75 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -29,6 +29,7 @@ from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.generation_parameters_copypaste import image_from_url_text +import modules.soft_inpainting as si create_setting_component = ui_settings.create_setting_component @@ -678,9 +679,16 @@ def copy_image(img): with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with FormRow(): + soft_inpainting = si.gradio_ui() + + + """ mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale") inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset") + """ with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -736,9 +744,7 @@ def select_img2img_tab(tab): sampler_name, mask_blur, mask_alpha, - mask_blend_power, - mask_blend_scale, - inpaint_detail_preservation, + *(soft_inpainting[0]), inpainting_fill, batch_count, batch_size, @@ -837,11 +843,10 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - (mask_blend_power, "Mask blending bias"), - (mask_blend_scale, "Mask blending preservation"), - (inpaint_detail_preservation, "Mask blending contrast boost"), + *(soft_inpainting[1]), 
*scripts.scripts_img2img.infotext_fields ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index bd9cb61bf82..f78886883bd 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,6 +10,7 @@ from modules import images from modules.processing import Processed, process_images from modules.shared import opts, state +import modules.soft_inpainting as si # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -133,16 +134,14 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) + soft_inpainting = si.gradio_ui()[0] direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation] + return [info, pixels, mask_blur, *soft_inpainting, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -170,9 +169,9 @@ def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpai p.mask_blur_x = mask_blur_x*4 p.mask_blur_y = mask_blur_y*4 - p.mask_blend_power = mask_blend_power - p.mask_blend_scale = mask_blend_scale - p.inpaint_detail_preservation = inpaint_detail_preservation + + p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ + if mask_blend_enabled else None init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 5388f5db462..11f7f74a820 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,6 +7,7 @@ from modules import images, devices from modules.processing import Processed, process_images from modules.shared import opts, state +import modules.soft_inpainting as si class 
Script(scripts.Script): @@ -22,23 +23,19 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) + soft_inpainting = si.gradio_ui()[0] inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction] + return [pixels, mask_blur, *soft_inpainting, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 - p.mask_blend_power = mask_blend_power - p.mask_blend_scale = mask_blend_scale - p.inpaint_detail_preservation = inpaint_detail_preservation - + p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ + if mask_blend_enabled else None p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 5cda2dbaefe..87bd850912e 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -1,6 +1,7 @@ import pytest import requests +import modules.soft_inpainting as si @pytest.fixture() @@ -24,9 +25,10 @@ def simple_img2img_request(img2img_basic_image_base64): "inpainting_mask_invert": False, "mask": None, "mask_blur": 4, - "mask_blend_power": 1, - "mask_blend_scale": 0.5, - "inpaint_detail_preservation": 4, + "mask_blend_enabled": True, + "mask_blend_power": si.default.mask_blend_power, + "mask_blend_scale": si.default.mask_blend_scale, + "inpaint_detail_preservation": si.default.inpaint_detail_preservation, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From 259d33c3c8e27557cb9bab9b3a1dd7fc7450d16c Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 01:57:21 -0700 Subject: [PATCH 0377/1174] Enables the original functionality to be toggled on and off. 
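
In short: every soft-inpainting code path below is gated on whether p.soft_inpainting is set, so a value of None restores the pre-existing hard-mask behaviour exactly. Condensed, the two latent blending paths look roughly like the sketch below (illustrative only; blend_power and blend_scale stand in for the settings object's fields, and the soft path omits the magnitude correction from the earlier patch):

    import torch

    def blend_step(init_latent, x, mask, nmask, sigma, soft=None):
        if soft is None:
            # Original behaviour: hard overwrite of the masked region.
            return init_latent * mask + nmask * x
        # Soft behaviour: per-step opacity shaped by the current noise level.
        blend_power, blend_scale = soft
        t = nmask ** ((sigma ** blend_power) * blend_scale)
        return init_latent * (1 - t) + x * t

    init, x = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
    nmask = torch.rand(1, 4, 8, 8)
    hard = blend_step(init, x, 1 - nmask, nmask, sigma=7.0)
    soft = blend_step(init, x, 1 - nmask, nmask, sigma=7.0, soft=(1.0, 0.5))
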
--- modules/processing.py | 99 ++++++++++++++++++++++++++++++------------- 1 file changed, 70 insertions(+), 29 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index b209c84a31d..b40b1a40d49 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -88,9 +88,12 @@ def apply_overlay(image, paste_loc, index, overlays): return image -def create_binary_mask(image): +def create_binary_mask(image, round=True): if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255): - image = image.split()[-1].convert("L") + if round: + image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + else: + image = image.split()[-1].convert("L") else: image = image.convert('L') return image @@ -316,7 +319,7 @@ def unclip_image_conditioning(self, source_image): c_adm = torch.cat((c_adm, noise_level_emb), 1) return c_adm - def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): + def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): self.is_using_inpainting_conditioning = True # Handle the different mask inputs @@ -327,6 +330,11 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N conditioning_mask = np.array(image_mask.convert("L")) conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + if round_image_mask: + # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:]) @@ -350,7 +358,7 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N return image_conditioning - def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): + def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): source_image = devices.cond_cast_float(source_image) # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely @@ -362,7 +370,10 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None return self.edit_image_conditioning(source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image, + latent_image, + image_mask=image_mask, + round_image_mask=round_image_mask) if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) @@ -878,8 +889,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method + # Generate the mask(s) based on similarity between the original and denoised latent vectors - if getattr(p, "image_mask", None) is not None: + if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: # latent_mask = p.nmask[0].float().cpu() # convert the original mask into a form we use to scale distances for thresholding @@ -911,7 +923,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: converted_mask = converted_mask.astype(np.uint8) converted_mask = Image.fromarray(converted_mask) converted_mask = images.resize_image(2, converted_mask, p.width, p.height) - converted_mask = 
create_binary_mask(converted_mask) + converted_mask = create_binary_mask(converted_mask, round=False) # Remove aliasing artifacts using a gaussian blur. converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) @@ -1010,23 +1022,33 @@ def infotext(index=0, use_main_prompt=False): if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): - image_mask = p.masks_for_overlay[i].convert('RGB') - image_mask_composite = Image.composite( - original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), - images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') - - if opts.save_mask: - images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") - - if opts.save_mask_composite: - images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") - - if opts.return_mask: - output_images.append(image_mask) - - if opts.return_mask_composite: - output_images.append(image_mask_composite) + if save_samples and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): + if hasattr(p, 'masks_for_overlay') and p.masks_for_overlay: + image_mask = p.masks_for_overlay[i].convert('RGB') + image_mask_composite = Image.composite( + original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), + images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') + elif hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: + image_mask = p.mask_for_overlay.convert('RGB') + image_mask_composite = Image.composite( + original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), + images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + else: + image_mask = None + image_mask_composite = None + + if image_mask is not None and image_mask_composite is not None: + if opts.save_mask: + images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") + + if opts.save_mask_composite: + images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") + + if opts.return_mask: + output_images.append(image_mask) + + if opts.return_mask_composite: + output_images.append(image_mask_composite) del x_samples_ddim @@ -1439,6 +1461,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): nmask: torch.Tensor = field(default=None, init=False) image_conditioning: torch.Tensor = field(default=None, init=False) init_img_hash: str = field(default=None, init=False) + mask_for_overlay: Image = field(default=None, init=False) init_latent: torch.Tensor = field(default=None, init=False) def __post_init__(self): @@ -1471,7 +1494,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): if image_mask is not None: # image_mask is passed in as RGBA by Gradio to support alpha masks, # but we still want to support binary masks. 
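
As an aside, the round flag changed just below is what keeps the older pipelines bit-exact: with rounding, partially transparent alpha snaps to fully masked or unmasked, while soft inpainting keeps the 8-bit gradient. A small standalone PIL check, not part of the patch:

    from PIL import Image

    # A 1x3 RGBA strip: opaque, half-transparent, fully transparent.
    rgba = Image.new("RGBA", (3, 1))
    rgba.putdata([(255, 0, 0, 255), (0, 255, 0, 128), (0, 0, 255, 0)])

    alpha = rgba.split()[-1].convert("L")
    soft_mask = alpha                                         # round=False -> 255, 128, 0
    hard_mask = alpha.point(lambda x: 255 if x > 128 else 0)  # round=True  -> 255, 0, 0
    print(list(soft_mask.getdata()), list(hard_mask.getdata()))

Note that either path only runs when the alpha channel is not uniformly opaque, per the getextrema() check above.
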
- image_mask = create_binary_mask(image_mask) + image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None)) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) @@ -1489,6 +1512,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: + self.mask_for_overlay = image_mask if self.soft_inpainting is None else None mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1500,7 +1524,12 @@ def init(self, all_prompts, all_seeds, all_subseeds): else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - self.masks_for_overlay = [] + if self.soft_inpainting is None: + np_mask = np.array(image_mask) + np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) + + self.masks_for_overlay = [] if self.soft_inpainting is not None else None self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1522,8 +1551,15 @@ def init(self, all_prompts, all_seeds, all_subseeds): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - self.overlay_images.append(image) - self.masks_for_overlay.append(image_mask) + if self.soft_inpainting is not None: + # We apply the masks AFTER to adjust mask based on changed content. + self.overlay_images.append(image) + self.masks_for_overlay.append(image_mask) + else: + image_masked = Image.new('RGBa', (image.width, image.height)) + image_masked.paste(image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + self.overlay_images.append(image_masked.convert('RGBA')) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1576,6 +1612,8 @@ def init(self, all_prompts, all_seeds, all_subseeds): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] + if self.soft_inpainting is None: + latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) @@ -1587,7 +1625,10 @@ def init(self, all_prompts, all_seeds, all_subseeds): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, + self.init_latent, + image_mask, + self.soft_inpainting is None) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() From 15322e1b1a9e31edcc2f7d72a32d02365058737d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 12:36:41 +0300 Subject: [PATCH 0378/1174] repair old handler for postprocessing API --- modules/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 3c85a74c1b3..d166f859b68 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -153,4 +153,4 @@ def run_extras(extras_mode, resize_mode, image, image_folder, 
input_dir, output_ }, }) - return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) + return run_postprocessing("", extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) From 24dae9bc4cc03a30236957d9c35d37aed79f6f5d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 12:36:41 +0300 Subject: [PATCH 0379/1174] repair old handler for postprocessing API --- modules/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 3c85a74c1b3..d166f859b68 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -153,4 +153,4 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ }, }) - return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) + return run_postprocessing("", extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) From 883d6a2b34a2817304d23c2481a6f9fc56687a53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 13:11:00 +0300 Subject: [PATCH 0380/1174] repair old handler for postprocessing API in a way that doesn't break interface --- modules/postprocessing.py | 8 ++++++-- modules/ui_postprocessing.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index d166f859b68..0c59fad480b 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -6,7 +6,7 @@ from modules.shared import opts -def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): +def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() shared.state.begin(job="extras") @@ -128,6 +128,10 @@ def get_images(extras_mode, image, image_folder, input_dir): return outputs, ui_common.plaintext_to_html(infotext), '' +def run_postprocessing_webui(id_task, *args, **kwargs): + return run_postprocessing(*args, **kwargs) + + def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): """old handler for API""" @@ -153,4 +157,4 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ }, }) - return run_postprocessing("", extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) + return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index fbad0800aca..13d888e48d1 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -35,7 +35,7 @@ def create_ui(): tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) submit.click( - fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + 
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']), _js="submit_extras", inputs=[ dummy_component, From 81105ee0135f1c475920bf44d3a04fc181aed29e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 13:11:00 +0300 Subject: [PATCH 0381/1174] repair old handler for postprocessing API in a way that doesn't break interface --- modules/postprocessing.py | 8 ++++++-- modules/ui_postprocessing.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index d166f859b68..0c59fad480b 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -6,7 +6,7 @@ from modules.shared import opts -def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): +def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() shared.state.begin(job="extras") @@ -128,6 +128,10 @@ def get_images(extras_mode, image, image_folder, input_dir): return outputs, ui_common.plaintext_to_html(infotext), '' +def run_postprocessing_webui(id_task, *args, **kwargs): + return run_postprocessing(*args, **kwargs) + + def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): """old handler for API""" @@ -153,4 +157,4 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ }, }) - return run_postprocessing("", extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) + return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index fbad0800aca..13d888e48d1 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -35,7 +35,7 @@ def create_ui(): tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) submit.click( - fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']), _js="submit_extras", inputs=[ dummy_component, From 22e23dbf29b0bbc807daa57318c31145f8dd0774 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 15:56:03 +0300 Subject: [PATCH 0382/1174] add hypertile infotext --- .../hypertile/scripts/hypertile_script.py | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py index d3ab6091500..395d584b605 100644 --- a/extensions-builtin/hypertile/scripts/hypertile_script.py +++ b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -17,11 +17,42 @@ def process(self, p, *args): configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) + self.add_infotext(p) + def before_hr(self, p, *args): + + enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet + # 
exclusive hypertile seed for the second pass - if not shared.opts.hypertile_enable_unet: + if enable: hypertile.set_hypertile_seed(p.all_seeds[0]) - configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass) + + configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable) + + if enable and not shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net second pass"] = True + + self.add_infotext(p, add_unet_params=True) + + def add_infotext(self, p, add_unet_params=False): + def option(name): + value = getattr(shared.opts, name) + default_value = shared.opts.get_default(name) + return None if value == default_value else value + + if shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net"] = True + + if shared.opts.hypertile_enable_unet or add_unet_params: + p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet') + p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet') + p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet') + + if shared.opts.hypertile_enable_vae: + p.extra_generation_params["Hypertile VAE"] = True + p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae') + p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae') + p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae') def configure_hypertile(width, height, enable_unet=True): @@ -57,16 +88,16 @@ def on_ui_settings(): benefit. """), - "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"), - "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"), - "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), - "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), + "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"), + "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"), + "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"), + "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"), + "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"), - "hypertile_enable_vae": shared.OptionInfo(False, 
"Enable Hypertile VAE").info("minimal change in the generated picture"), - "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), - "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), + "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"), + "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"), + "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"), + "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"), } for name, opt in options.items(): From 368d66c9ccca0270cc64d6a64d22bfa562f28361 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 15:56:03 +0300 Subject: [PATCH 0383/1174] add hypertile infotext --- .../hypertile/scripts/hypertile_script.py | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py index d3ab6091500..395d584b605 100644 --- a/extensions-builtin/hypertile/scripts/hypertile_script.py +++ b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -17,11 +17,42 @@ def process(self, p, *args): configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) + self.add_infotext(p) + def before_hr(self, p, *args): + + enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet + # exclusive hypertile seed for the second pass - if not shared.opts.hypertile_enable_unet: + if enable: hypertile.set_hypertile_seed(p.all_seeds[0]) - configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass) + + configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable) + + if enable and not shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net second pass"] = True + + self.add_infotext(p, add_unet_params=True) + + def add_infotext(self, p, add_unet_params=False): + def option(name): + value = getattr(shared.opts, name) + default_value = shared.opts.get_default(name) + return None if value == default_value else value + + if shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net"] = True + + if shared.opts.hypertile_enable_unet or add_unet_params: + p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet') + p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet') + p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet') + + if shared.opts.hypertile_enable_vae: + p.extra_generation_params["Hypertile VAE"] = True + p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae') + p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae') + 
p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae') def configure_hypertile(width, height, enable_unet=True): @@ -57,16 +88,16 @@ def on_ui_settings(): benefit. """), - "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"), - "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"), - "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), - "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), + "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"), + "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"), + "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"), + "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"), + "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"), - "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"), - "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), - "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), + "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"), + "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"), + "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"), + "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"), } for name, opt in options.items(): From 854f8c318c2610c76259056ab02739176aa849e8 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 5 Dec 2023 04:40:12 +0900 Subject: [PATCH 0384/1174] remove clean_text() --- modules/styles.py | 23 +++-------------------- 1 file changed, 3 
insertions(+), 20 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 4d218cd7ed0..7fb6c2e1177 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -2,7 +2,6 @@ import fnmatch import os import os.path -import re import typing import shutil @@ -14,22 +13,6 @@ class PromptStyle(typing.NamedTuple): path: str = None -def clean_text(text: str) -> str: - """ - Iterating through a list of regular expressions and replacement strings, we - clean up the prompt and style text to make it easier to match against each - other. - """ - re_list = [ - ("multiple commas", re.compile("(,+\s+)+,?"), ", "), - ("multiple spaces", re.compile("\s{2,}"), " "), - ] - for _, regex, replace in re_list: - text = regex.sub(replace, text) - - return text.strip(", ") - - def merge_prompts(style_prompt: str, prompt: str) -> str: if "{prompt}" in style_prompt: res = style_prompt.replace("{prompt}", prompt) @@ -44,7 +27,7 @@ def apply_styles_to_prompt(prompt, styles): for style in styles: prompt = merge_prompts(style, prompt) - return clean_text(prompt) + return prompt def unwrap_style_text_from_prompt(style_text, prompt): @@ -56,8 +39,8 @@ def unwrap_style_text_from_prompt(style_text, prompt): Note that the "cleaned" version of the style text is only used for matching purposes here. It isn't returned; the original style text is not modified. """ - stripped_prompt = clean_text(prompt) - stripped_style_text = clean_text(style_text) + stripped_prompt = prompt + stripped_style_text = style_text if "{prompt}" in stripped_style_text: # Work out whether the prompt is wrapped in the style text. If so, we # return True and the "inner" prompt text that isn't part of the style. From 976c1053efeb5054692ed3cfa294cf79196f3946 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 16:06:58 -0700 Subject: [PATCH 0385/1174] Cleaned up code, moved main code contributions into soft_inpainting.py --- modules/processing.py | 56 ++------- modules/sd_samplers_cfg_denoiser.py | 84 ++----------- modules/soft_inpainting.py | 177 ++++++++++++++++++++++++---- modules/ui.py | 7 -- 4 files changed, 174 insertions(+), 150 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index b40b1a40d49..0b3603875a7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -892,55 +892,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: # Generate the mask(s) based on similarity between the original and denoised latent vectors if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - # latent_mask = p.nmask[0].float().cpu() - - # convert the original mask into a form we use to scale distances for thresholding - # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) - # mask_scalar = mask_scalar / (1.00001-mask_scalar) - # mask_scalar = mask_scalar.numpy() - - latent_orig = p.init_latent - latent_proc = samples_ddim - latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1) - - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) - - for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)): - converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) - - # 
The distance at which opacity of original decreases to 50% - # half_weighted_distance = 1 # * mask_scalar - # converted_mask = converted_mask / half_weighted_distance - - converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) - converted_mask = 1 - converted_mask - converted_mask = 255. * converted_mask - converted_mask = converted_mask.astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, p.width, p.height) - converted_mask = create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if p.paste_to is not None: - converted_mask = uncrop(converted_mask, - (overlay_image.width, overlay_image.height), - p.paste_to) - - p.masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - p.overlay_images[i] = image_masked.convert('RGBA') + si.generate_adaptive_masks(latent_orig=p.init_latent, + latent_processed=samples_ddim, + overlay_images=p.overlay_images, + masks_for_overlay=p.masks_for_overlay, + width=p.width, + height=p.height, + paste_to=p.paste_to) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 0ee0b7dde03..a700e6922fc 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -94,76 +94,6 @@ def update_inner_model(self): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - def latent_blend(a, b, t, one_minus_t=None): - - """ - Interpolates two latent image representations according to the parameter t, - where the interpolated vectors' magnitudes are also interpolated separately. - The "detail_preservation" factor biases the magnitude interpolation towards - the larger of the two magnitudes. - """ - # NOTE: We use inplace operations wherever possible. - - if one_minus_t is None: - one_minus_t = 1 - t - - if self.soft_inpainting is None: - return a * one_minus_t + b * t - - # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t - b_scaled = b * t - image_interp = a_scaled - image_interp.add_(b_scaled) - result_type = image_interp.dtype - del a_scaled, b_scaled - - # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) - # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) - - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * one_minus_t - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * t - desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / self.soft_inpainting.inpaint_detail_preservation) - del a_magnitude, b_magnitude, one_minus_t - - # Change the linearly interpolated image vectors' magnitudes to the value we want. - # This is the last 64-bit operation. 
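
The float64 casts this block keeps calling out are load-bearing: inpaint_detail_preservation can be as large as 32, and norms raised to such powers overflow float32 immediately. A quick check, not from the patch:

    import torch

    norm = torch.tensor(40.0)              # a plausible per-pixel channel norm
    print(norm.to(torch.float32).pow(32))  # inf: 40**32 ~ 1.8e51 exceeds float32's ~3.4e38 max
    print(norm.to(torch.float64).pow(32))  # ~1.8447e+51, representable in float64

Taking pow(1 / detail_preservation) afterwards brings the interpolated magnitudes back into float32 range before the final cast.
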
- image_interp_scaling_factor = desired_magnitude - image_interp_scaling_factor.div_(current_magnitude) - image_interp_scaled = image_interp - image_interp_scaled.mul_(image_interp_scaling_factor) - del current_magnitude - del desired_magnitude - del image_interp - del image_interp_scaling_factor - - image_interp_scaled = image_interp_scaled.to(result_type) - del result_type - - return image_interp_scaled - - def get_modified_nmask(nmask, _sigma): - """ - Converts a negative mask representing the transparency of the original latent vectors being overlayed - to a mask that is scaled according to the denoising strength for this step. - - Where: - 0 = fully opaque, infinite density, fully masked - 1 = fully transparent, zero density, fully unmasked - - We bring this transparency to a power, as this allows one to simulate N number of blending operations - where N can be any positive real value. Using this one can control the balance of influence between - the denoiser and the original latents according to the sigma value. - - NOTE: "mask" is not used - """ - if self.soft_inpainting is None: - return nmask - - return torch.pow(nmask, (_sigma ** self.soft_inpainting.mask_blend_power) * self.soft_inpainting.mask_blend_scale) if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -184,9 +114,12 @@ def get_modified_nmask(nmask, _sigma): # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: if self.soft_inpainting is None: - x = latent_blend(self.init_latent, x, self.nmask, self.mask) + x = self.init_latent * self.mask + self.nmask * x else: - x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) + x = si.latent_blend(self.soft_inpainting, + self.init_latent, + x, + si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -290,9 +223,12 @@ def get_modified_nmask(nmask, _sigma): # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: if self.soft_inpainting is None: - denoised = latent_blend(self.init_latent, denoised, self.nmask, self.mask) + denoised = self.init_latent * self.mask + self.nmask * denoised else: - denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) + denoised = si.latent_blend(self.soft_inpainting, + self.init_latent, + denoised, + si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py index 259c36ec892..b81c8dd95d1 100644 --- a/modules/soft_inpainting.py +++ b/modules/soft_inpainting.py @@ -4,13 +4,6 @@ def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservati self.mask_blend_scale = mask_blend_scale self.inpaint_detail_preservation = inpaint_detail_preservation - def get_paste_fields(self): - return [ - (self.mask_blend_power, gen_param_labels.mask_blend_power), - (self.mask_blend_scale, gen_param_labels.mask_blend_scale), - (self.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation), - ] - def add_generation_params(self, dest): dest[enabled_gen_param_label] = True dest[gen_param_labels.mask_blend_power] = self.mask_blend_power @@ -18,25 +11,169 @@ def add_generation_params(self, dest): 
dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation +# ------------------- Methods ------------------- + + +def latent_blend(soft_inpainting, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + one_minus_t = 1 - t + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t + b_scaled = b * t + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) + del a_magnitude, b_magnitude, one_minus_t + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(soft_inpainting, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + + +def generate_adaptive_masks( + latent_orig, + latent_processed, + overlay_images, + masks_for_overlay, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. 
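
The distance-to-opacity mapping assembled further below (1 / (1 + distance**2) followed by smootherstep) is easy to sanity-check in isolation. This scalar re-implementation uses a local smootherstep, since images.smootherstep and the weighted histogram filters are webui helpers; it is an approximation for illustration only:

    import numpy as np

    def smootherstep(x):
        # 6x^5 - 15x^4 + 10x^3, the classic C2-continuous ramp on [0, 1].
        x = np.clip(x, 0.0, 1.0)
        return x * x * x * (x * (x * 6.0 - 15.0) + 10.0)

    def distance_to_opacity(distance):
        # Small latent change -> 0 (original pixels kept in the overlay);
        # large change -> 255 (generated content shows through).
        similarity = 1.0 / (1.0 + distance ** 2)
        return (255.0 * (1.0 - smootherstep(similarity))).astype(np.uint8)

    print(distance_to_opacity(np.array([0.0, 0.5, 1.0, 2.0, 4.0])))
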
+ # latent_mask = p.nmask[0].float().cpu() + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc. uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + paste_to) + + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + +# ------------------- Constants ------------------- + + +default = SoftInpaintingSettings(1, 0.5, 4) + enabled_ui_label = "Soft inpainting" enabled_gen_param_label = "Soft inpainting enabled" enabled_el_id = "soft_inpainting_enabled" -default = SoftInpaintingSettings(1, 0.5, 4) -ui_labels = SoftInpaintingSettings("Schedule bias", "Preservation strength", "Transition contrast boost") +ui_labels = SoftInpaintingSettings( + "Schedule bias", + "Preservation strength", + "Transition contrast boost") ui_info = SoftInpaintingSettings( - mask_blend_power="Shifts when preservation of original content occurs during denoising.", - # "Below 1: Stronger preservation near the end (with low sigma)\n" - # "1: Balanced (proportional to sigma)\n" - # "Above 1: Stronger preservation in the beginning (with high sigma)", - mask_blend_scale="How strongly partially masked content should be preserved.", - # "Low values: Favors generated content.\n" - # "High values: Favors original content.", - inpaint_detail_preservation="Amplifies the contrast that may be lost in partially masked regions.") - -gen_param_labels = SoftInpaintingSettings("Soft inpainting schedule bias", "Soft inpainting preservation strength", "Soft inpainting transition contrast boost") -el_ids = SoftInpaintingSettings("mask_blend_power", "mask_blend_scale", "inpaint_detail_preservation") + "Shifts when preservation of original content occurs 
during denoising.", + "How strongly partially masked content should be preserved.", + "Amplifies the contrast that may be lost in partially masked regions.") + +gen_param_labels = SoftInpaintingSettings( + "Soft inpainting schedule bias", + "Soft inpainting preservation strength", + "Soft inpainting transition contrast boost") + +el_ids = SoftInpaintingSettings( + "mask_blend_power", + "mask_blend_scale", + "inpaint_detail_preservation") + + +# ------------------- UI ------------------- def gradio_ui(): diff --git a/modules/ui.py b/modules/ui.py index 0e4fb17aa75..4f1265a3e8b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -683,13 +683,6 @@ def copy_image(img): with FormRow(): soft_inpainting = si.gradio_ui() - - """ - mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale") - inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset") - """ - with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") From 1455159cf44cd8c21656818463f6095eae887540 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 16:43:57 -0700 Subject: [PATCH 0386/1174] Fixed issue with whitespace, removed commented out code that was meant to be used as a reference. --- modules/soft_inpainting.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py index b81c8dd95d1..56a87774693 100644 --- a/modules/soft_inpainting.py +++ b/modules/soft_inpainting.py @@ -179,15 +179,7 @@ def generate_adaptive_masks( def gradio_ui(): import gradio as gr from modules.ui_components import InputAccordion - """ - with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner: - with gr.Row(): - refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") - create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) - - refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") - """ with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: with gr.Group(): gr.Markdown( @@ -223,11 +215,11 @@ def gradio_ui(): gr.Markdown( f""" ### {ui_labels.mask_blend_power} - + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. 
- + - **Below 1**: Stronger preservation near the end (with low sigma) - **1**: Balanced (proportional to sigma) - **Above 1**: Stronger preservation in the beginning (with high sigma) @@ -235,21 +227,21 @@ def gradio_ui(): gr.Markdown( f""" ### {ui_labels.mask_blend_scale} - + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. - + - **Low values**: Favors generated content. - **High values**: Favors original content. """) gr.Markdown( f""" ### {ui_labels.inpaint_detail_preservation} - + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. This can prevent the loss of contrast that occurs with linear interpolation. - + - **Low values**: Softer blending, details may fade. - **High values**: Stronger contrast, may over-saturate colors. """) From 57f29bd61dc30f1a8c94ead9b780f4655f7d7d6d Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 17:41:18 -0700 Subject: [PATCH 0387/1174] Re-introduce latent blending step from the vanilla inpainting procedure. --- modules/processing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index 0b3603875a7..c8dc4d9340e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1597,6 +1597,9 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) + if self.mask is not None and self.soft_inpainting is None: + samples = samples * self.nmask + self.init_latent * self.mask + del x devices.torch_gc() From 60c602232fd760fb548fb0b3d18b5297f8823c2a Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 17:41:51 -0700 Subject: [PATCH 0388/1174] Restored original formatting. --- modules/processing.py | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c8dc4d9340e..90ae249a4db 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -370,10 +370,7 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None return self.edit_image_conditioning(source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, - latent_image, - image_mask=image_mask, - round_image_mask=round_image_mask) + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask) if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) @@ -885,7 +882,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim - # todo: generate masks the old fashioned way + # todo: generate adaptive masks based on pixel differences. 
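+ # (in this branch the sampler returned already-decoded images, so the latent-difference mask generation in the else-branch below is unavailable)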
+ # if p.masks_for_overlay is used, it will already be populated with masks else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method @@ -900,9 +898,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: height=p.height, paste_to=p.paste_to) - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, - target_device=devices.cpu, - check_for_nans=True) + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -927,9 +923,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = batch_params.images def infotext(index=0, use_main_prompt=False): - return create_infotext(p, p.prompts, p.seeds, p.subseeds, - use_main_prompt=use_main_prompt, index=index, - all_negative_prompts=p.negative_prompts) + return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) save_samples = p.save_samples() @@ -972,8 +966,7 @@ def infotext(index=0, use_main_prompt=False): image = apply_overlay(image, p.paste_to, i, p.overlay_images) if save_samples: - images.save_image(image, p.outpath_samples, "", p.seeds[i], - p.prompts[i], opts.samples_format, info=infotext(i), p=p) + images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) text = infotext(i) infotexts.append(text) @@ -983,14 +976,10 @@ def infotext(index=0, use_main_prompt=False): if save_samples and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): if hasattr(p, 'masks_for_overlay') and p.masks_for_overlay: image_mask = p.masks_for_overlay[i].convert('RGB') - image_mask_composite = Image.composite( - original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), - images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') + image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') elif hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: image_mask = p.mask_for_overlay.convert('RGB') - image_mask_composite = Image.composite( - original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), - images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') else: image_mask = None image_mask_composite = None @@ -1515,8 +1504,8 @@ def init(self, all_prompts, all_seeds, all_subseeds): self.masks_for_overlay.append(image_mask) else: image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + self.overlay_images.append(image_masked.convert('RGBA')) # crop_region is not None if we are doing inpaint full res @@ -1583,10 
+1572,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, - self.init_latent, - image_mask, - self.soft_inpainting is None) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() From b32a334e3da7b06d82441beaa08a673b4f55bca1 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 17:57:10 -0700 Subject: [PATCH 0389/1174] Applies a convert('RGBA') operation early to mimic previous behaviour. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 90ae249a4db..7fc282cfd21 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1500,7 +1500,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): if image_mask is not None: if self.soft_inpainting is not None: # We apply the masks AFTER to adjust mask based on changed content. - self.overlay_images.append(image) + self.overlay_images.append(image.convert('RGBA')) self.masks_for_overlay.append(image_mask) else: image_masked = Image.new('RGBa', (image.width, image.height)) From 6fc12428e3c5f903584ca7986e0c441f80fa2807 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 19:42:59 -0700 Subject: [PATCH 0390/1174] Fixed issue where batched inpainting (batch size > 1) wouldn't work because of mismatched tensor sizes. The 'already_decoded' decoded case should also be handled correctly (tested indirectly). --- modules/processing.py | 23 ++++++++----- modules/soft_inpainting.py | 66 ++++++++++++++++++++++++++++++++------ 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 7fc282cfd21..71bb056a26b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -883,20 +883,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim # todo: generate adaptive masks based on pixel differences. 
- # if p.masks_for_overlay is used, it will already be populated with masks + if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: + si.apply_masks(soft_inpainting=p.soft_inpainting, + nmask=p.nmask, + overlay_images=p.overlay_images, + masks_for_overlay=p.masks_for_overlay, + width=p.width, + height=p.height, + paste_to=p.paste_to) else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method # Generate the mask(s) based on similarity between the original and denoised latent vectors if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.generate_adaptive_masks(latent_orig=p.init_latent, - latent_processed=samples_ddim, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) + si.apply_adaptive_masks(latent_orig=p.init_latent, + latent_processed=samples_ddim, + overlay_images=p.overlay_images, + masks_for_overlay=p.masks_for_overlay, + width=p.width, + height=p.height, + paste_to=p.paste_to) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py index 56a87774693..b36ac8fa1b7 100644 --- a/modules/soft_inpainting.py +++ b/modules/soft_inpainting.py @@ -25,26 +25,32 @@ def latent_blend(soft_inpainting, a, b, t): # NOTE: We use inplace operations wherever possible. - one_minus_t = 1 - t + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t - b_scaled = b * t + a_scaled = a * one_minus_t2 + b_scaled = b * t2 image_interp = a_scaled image_interp.add_(b_scaled) result_type = image_interp.dtype - del a_scaled, b_scaled + del a_scaled, b_scaled, t2, one_minus_t2 # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3 desired_magnitude = a_magnitude desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) - del a_magnitude, b_magnitude, one_minus_t + del a_magnitude, b_magnitude, t3, one_minus_t3 # Change the linearly interpolated image vectors' magnitudes to the value we want. # This is the last 64-bit operation. 
@@ -78,10 +84,11 @@ def get_modified_nmask(soft_inpainting, nmask, sigma): NOTE: "mask" is not used """ import torch - return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + # todo: Why is sigma 2D? Both values are the same. + return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) -def generate_adaptive_masks( +def apply_adaptive_masks( latent_orig, latent_processed, overlay_images, @@ -142,6 +149,45 @@ def generate_adaptive_masks( overlay_images[i] = image_masked.convert('RGBA') +def apply_masks( + soft_inpainting, + nmask, + overlay_images, + masks_for_overlay, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + converted_mask = nmask[0].float() + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) + converted_mask = 255. * converted_mask + converted_mask = converted_mask.cpu().numpy().astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (width, height), + paste_to) + + for i, overlay_image in enumerate(overlay_images): + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + # ------------------- Constants ------------------- From 49bbf1140731036875573bb7c44aa7e74623c856 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 19:47:40 -0700 Subject: [PATCH 0391/1174] Fixed unused import. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 71bb056a26b..e1823ac3321 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -9,7 +9,7 @@ import torch import numpy as np -from PIL import Image, ImageOps, ImageFilter +from PIL import Image, ImageOps import random import cv2 from skimage import exposure From 120a84bd2f01ec4489bd12bd68f319798ef30782 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 5 Dec 2023 07:15:39 +0300 Subject: [PATCH 0392/1174] Merge pull request #14203 from AUTOMATIC1111/remove-clean_text() remove clean_text() --- modules/styles.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 4d218cd7ed0..7fb6c2e1177 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -2,7 +2,6 @@ import fnmatch import os import os.path -import re import typing import shutil @@ -14,22 +13,6 @@ class PromptStyle(typing.NamedTuple): path: str = None -def clean_text(text: str) -> str: - """ - Iterating through a list of regular expressions and replacement strings, we - clean up the prompt and style text to make it easier to match against each - other. 
- """ - re_list = [ - ("multiple commas", re.compile("(,+\s+)+,?"), ", "), - ("multiple spaces", re.compile("\s{2,}"), " "), - ] - for _, regex, replace in re_list: - text = regex.sub(replace, text) - - return text.strip(", ") - - def merge_prompts(style_prompt: str, prompt: str) -> str: if "{prompt}" in style_prompt: res = style_prompt.replace("{prompt}", prompt) @@ -44,7 +27,7 @@ def apply_styles_to_prompt(prompt, styles): for style in styles: prompt = merge_prompts(style, prompt) - return clean_text(prompt) + return prompt def unwrap_style_text_from_prompt(style_text, prompt): @@ -56,8 +39,8 @@ def unwrap_style_text_from_prompt(style_text, prompt): Note that the "cleaned" version of the style text is only used for matching purposes here. It isn't returned; the original style text is not modified. """ - stripped_prompt = clean_text(prompt) - stripped_style_text = clean_text(style_text) + stripped_prompt = prompt + stripped_style_text = style_text if "{prompt}" in stripped_style_text: # Work out whether the prompt is wrapped in the style text. If so, we # return True and the "inner" prompt text that isn't part of the style. From 895456c4a2e87f5fe3ee23b4482e68fce317a1ca Mon Sep 17 00:00:00 2001 From: Jabasukuriputo Wang Date: Tue, 5 Dec 2023 18:00:48 -0600 Subject: [PATCH 0393/1174] change state dict comparison to ref compare --- modules/sd_disable_initialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 8863107ae6f..273a7edd8b4 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -215,7 +215,7 @@ def load_state_dict(original, module, state_dict, strict=True): would be on the meta device. """ - if state_dict == sd: + if state_dict is sd: state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} original(module, state_dict, strict=strict) From 672dc4efa8e0da38426b121e7c7216d0a8e465fd Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:16:10 +0800 Subject: [PATCH 0394/1174] Fix forced reload --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index dcf816b3a85..d0046f88c6d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -801,7 +801,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): if check_fp8(sd_model) != devices.fp8: # load from state dict again to prevent extra numerical errors forced_reload = True - elif sd_model.sd_model_checkpoint == checkpoint_info.filename: + elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) From 4d56383025f2cbd00dc6296161e31a896624ab75 Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Wed, 6 Dec 2023 20:23:56 +0800 Subject: [PATCH 0395/1174] Long distance memory overflow issue Problem: The memory will slowly increase with the drawing until restarting. Observation: GC analysis shows that no occupation has occurred, so it is suspected to be a problem with the underlying allocator. Reason: Under Linux, glibc is used to allocate memory. glibc uses brk and mmap to allocate memory, and the memory allocated by brk cannot be released until the high-address memory is released. 
In other words, if two blocks A and B are allocated through brk, A cannot be released before B is freed, and it remains charged to the process, which shows up in top as the suspected "memory leak". So I switched to TCMalloc, but found that libtcmalloc_minimal could not resolve pthread_key_create; analysis showed that pthread support had not been compiled in. --- webui.sh | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/webui.sh b/webui.sh index 3d0f87eed74..081624c4eb8 100755 --- a/webui.sh +++ b/webui.sh @@ -222,13 +222,29 @@ fi # Try using TCMalloc on Linux prepare_tcmalloc() { if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then - TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)" - if [[ ! -z "${TCMALLOC}" ]]; then - echo "Using TCMalloc: ${TCMALLOC}" - export LD_PRELOAD="${TCMALLOC}" - else - printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n" - fi + # Define the TCMalloc library patterns to search for + TCMALLOC_LIBS=("libtcmalloc(_minimal|)\.so\.\d" "libtcmalloc\.so\.\d") + + # Traverse the array + for lib in "${TCMALLOC_LIBS[@]}" + do + # Determine which tcmalloc library variant is available + TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P "$lib" | head -n 1)" + TC_INFO=(${TCMALLOC//=>/}) + if [[ ! -z "${TC_INFO}" ]]; then + echo "Using TCMalloc: ${TC_INFO}" + # Determine whether the library is linked against libpthread, to avoid "undefined symbol: pthread_key_create" + if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then + echo "$TC_INFO is linked with libpthread, executing LD_PRELOAD=${TC_INFO}" + export LD_PRELOAD="${TC_INFO}" + break + else + echo "$TC_INFO is not linked with libpthread and would trigger an undefined symbol: pthread_key_create error" + fi + else + printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n" + fi + done fi } From 746783f7a47f38f728f221cc26fe04035d3ca66b Mon Sep 17 00:00:00 2001 From: Nuullll Date: Wed, 6 Dec 2023 20:55:42 +0800 Subject: [PATCH 0396/1174] [IPEX] Fix embedding Cast `torch.bmm` args into same `dtype`.
Fixes the following error when using Text Inversion embedding (#14224): ``` RuntimeError: could not create a primitive descriptor for a matmul primitive ``` --- modules/xpu_specific.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d933c790378..ec1ad100aae 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -48,3 +48,6 @@ def torch_xpu_gc(): CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) From 9d2cbf8e97832662e446145d3961c39e78919d3d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 6 Dec 2023 23:06:32 +0900 Subject: [PATCH 0397/1174] add option: Live preview in full page image viewer make #13459 "show the preview image in the modal view if available" optional --- javascript/imageviewer.js | 2 +- modules/shared_options.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index e4dae91bc9d..625c5d148df 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -34,7 +34,7 @@ function updateOnBackgroundChange() { if (modalImage && modalImage.offsetParent) { let currentButton = selected_gallery_button(); let preview = gradioApp().querySelectorAll('.livePreview > img'); - if (preview.length > 0) { + if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) { // show preview image if available modalImage.src = preview[preview.length - 1].src; } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018fb..88cfddedeb4 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -330,6 +330,7 @@ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), + "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { From e90d4334ad37024a802f4ef27069b625a6508f72 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 16:54:42 -0700 Subject: [PATCH 0398/1174] A custom blending function can be provided by p, replacing the use of soft_inpainting. 
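For illustration, an extension could attach its own blend function roughly like this (a sketch only: `my_linear_blend` is a hypothetical name, and its body just reproduces the default linear blend that runs when no override is set):

```python
def my_linear_blend(denoiser, args):
    # args carries "denoiser", "current_latent" and "sigma", so new fields
    # can be added later without breaking existing extensions.
    latent = args["current_latent"]
    return denoiser.init_latent * denoiser.mask + denoiser.nmask * latent

p.denoiser_masked_blend_function = my_linear_blend
```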
--- modules/sd_samplers_cfg_denoiser.py | 34 ++++++++++++++--------------- modules/sd_samplers_common.py | 1 - 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index a700e6922fc..f13e8dcc5b0 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -6,7 +6,6 @@ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback -import modules.soft_inpainting as si def catenate_conds(conds): @@ -44,7 +43,6 @@ def __init__(self, sampler): self.model_wrap = None self.mask = None self.nmask = None - self.soft_inpainting: si.SoftInpaintingParameters = None self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -94,7 +92,6 @@ def update_inner_model(self): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -111,15 +108,24 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + # If we use masks, blending between the denoised and original latent images occurs here. + def apply_blend(latent): + if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function): + return self.p.denoiser_masked_blend_function( + self, + # Using an argument dictionary so that arguments can be added without breaking extensions. 
+ args= + { + "denoiser": self, + "current_latent": latent, + "sigma": sigma + }) + else: + return self.init_latent * self.mask + self.nmask * latent + # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - if self.soft_inpainting is None: - x = self.init_latent * self.mask + self.nmask * x - else: - x = si.latent_blend(self.soft_inpainting, - self.init_latent, - x, - si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) + x = apply_blend(x) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -222,13 +228,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - if self.soft_inpainting is None: - denoised = self.init_latent * self.mask + self.nmask * denoised - else: - denoised = si.latent_blend(self.soft_inpainting, - self.init_latent, - denoised, - si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) + denoised = apply_blend(denoised) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 9682bee3d8e..58efcad2374 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -277,7 +277,6 @@ def initialize(self, p) -> dict: self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.soft_inpainting = p.soft_inpainting if hasattr(p, 'soft_inpainting') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) From 4608f6236fc24d937f89500b2c9bf48484537cf9 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 18:11:17 -0700 Subject: [PATCH 0399/1174] Removed changes in some scripts since the arguments for soft inpainting are no longer passed through the same path as "mask_blur".
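Context for the revert, as a rough sketch of the convention involved rather than the full Script API: values a script returns from ui() are passed positionally into its run(), which is why the soft-inpainting controls previously had to be spliced into every script's argument list.

```python
class Script(scripts.Script):
    def ui(self, is_img2img):
        pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128)
        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
        return [pixels, mask_blur]           # values returned here...

    def run(self, p, pixels, mask_blur):     # ...arrive here, in the same order
        ...
```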
--- modules/img2img.py | 50 +------------------------------- modules/ui.py | 7 ----- scripts/outpainting_mk_2.py | 9 ++---- scripts/poor_mans_outpainting.py | 8 ++--- test/test_img2img.py | 5 ---- 5 files changed, 5 insertions(+), 74 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 3aa8a9cef6c..c583290a0ea 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -15,7 +15,6 @@ import modules.processing as processing from modules.ui import plaintext_to_html import modules.scripts -import modules.soft_inpainting as si def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): @@ -147,48 +146,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal return batch_results -def img2img(id_task: str, - mode: int, - prompt: str, - negative_prompt: str, - prompt_styles, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, - inpaint_color_sketch_orig, - init_img_inpaint, - init_mask_inpaint, - steps: int, - sampler_name: str, - mask_blur: int, - mask_alpha: float, - mask_blend_enabled: bool, - mask_blend_power: float, - mask_blend_scale: float, - inpaint_detail_preservation: float, - inpainting_fill: int, - n_iter: int, - batch_size: int, - cfg_scale: float, - image_cfg_scale: float, - denoising_strength: float, - selected_scale_tab: int, - height: int, - width: int, - scale_by: float, - resize_mode: int, - inpaint_full_res: bool, - inpaint_full_res_padding: int, - inpainting_mask_invert: int, - img2img_batch_input_dir: str, - img2img_batch_output_dir: str, - img2img_batch_inpaint_mask_dir: str, - override_settings_texts, - img2img_batch_use_png_info: bool, - img2img_batch_png_info_props: list, - img2img_batch_png_info_dir: str, - request: gr.Request, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -229,9 +187,6 @@ def img2img(id_task: str, assert 0. 
<= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' - soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ - if mask_blend_enabled else None - p = StableDiffusionProcessingImg2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples, @@ -249,7 +204,6 @@ def img2img(id_task: str, init_images=[image], mask=mask, mask_blur=mask_blur, - soft_inpainting=soft_inpainting, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -270,8 +224,6 @@ def img2img(id_task: str, if mask: p.extra_generation_params["Mask blur"] = mask_blur - if soft_inpainting is not None: - soft_inpainting.add_generation_params(p.extra_generation_params) with closing(p): if is_batch: diff --git a/modules/ui.py b/modules/ui.py index bd2091e1ff9..d80486dd4ac 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -29,7 +29,6 @@ from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.generation_parameters_copypaste import image_from_url_text -import modules.soft_inpainting as si create_setting_component = ui_settings.create_setting_component @@ -680,9 +679,6 @@ def copy_image(img): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - with FormRow(): - soft_inpainting = si.gradio_ui() - with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -737,7 +733,6 @@ def select_img2img_tab(tab): sampler_name, mask_blur, mask_alpha, - *(soft_inpainting[0]), inpainting_fill, batch_count, batch_size, @@ -836,10 +831,8 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - *(soft_inpainting[1]), *scripts.scripts_img2img.infotext_fields ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index f78886883bd..c98ab48098e 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,7 +10,6 @@ from modules import images from modules.processing import Processed, process_images from modules.shared import opts, state -import modules.soft_inpainting as si # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -134,14 +133,13 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - soft_inpainting = si.gradio_ui()[0] direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = 
gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, *soft_inpainting, direction, noise_q, color_variation] + return [info, pixels, mask_blur, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -170,9 +168,6 @@ def run(self, p, _, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mas p.mask_blur_x = mask_blur_x*4 p.mask_blur_y = mask_blur_y*4 - p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ - if mask_blend_enabled else None - init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 target_h = math.ceil((init_img.height + up + down) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 11f7f74a820..ea0632b68c1 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,7 +7,6 @@ from modules import images, devices from modules.processing import Processed, process_images from modules.shared import opts, state -import modules.soft_inpainting as si class Script(scripts.Script): @@ -23,19 +22,16 @@ def ui(self, is_img2img): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - soft_inpainting = si.gradio_ui()[0] inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, *soft_inpainting, inpainting_fill, direction] + return [pixels, mask_blur, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 - p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ - if mask_blend_enabled else None p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 87bd850912e..117d2d1eb45 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -1,7 +1,6 @@ import pytest import requests -import modules.soft_inpainting as si @pytest.fixture() @@ -25,10 +24,6 @@ def simple_img2img_request(img2img_basic_image_base64): "inpainting_mask_invert": False, "mask": None, "mask_blur": 4, - "mask_blend_enabled": True, - "mask_blend_power": si.default.mask_blend_power, - "mask_blend_scale": si.default.mask_blend_scale, - "inpaint_detail_preservation": si.default.inpaint_detail_preservation, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From ac4578912395627731f2cd8529f87a95df1f7644 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 
21:16:27 -0700 Subject: [PATCH 0400/1174] Removed soft inpainting, added hooks for soft inpainting to work instead. --- modules/processing.py | 94 ++++++++++++----------------- modules/scripts.py | 70 +++++++++++++++++++++ modules/sd_samplers_cfg_denoiser.py | 23 +++---- 3 files changed, 118 insertions(+), 69 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 7d46949fa3e..5a1a90afedd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -30,7 +30,6 @@ import modules.sd_vae as sd_vae from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion -import modules.soft_inpainting as si from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType @@ -73,12 +72,10 @@ def uncrop(image, dest_size, paste_loc): return image -def apply_overlay(image, paste_loc, index, overlays): - if overlays is None or index >= len(overlays): +def apply_overlay(image, paste_loc, overlay): + if overlay is None: return image - overlay = overlays[index] - if paste_loc is not None: image = uncrop(image, (overlay.width, overlay.height), paste_loc) @@ -150,7 +147,6 @@ class StableDiffusionProcessing: do_not_save_grid: bool = False extra_generation_params: dict[str, Any] = None overlay_images: list = None - masks_for_overlay: list = None eta: float = None do_not_reload_embeddings: bool = False denoising_strength: float = None @@ -880,31 +876,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + if p.scripts is not None: + ps = scripts.PostSampleArgs(samples_ddim) + p.scripts.post_sample(p, ps) + samples_ddim = pp.samples + if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim - # todo: generate adaptive masks based on pixel differences.
- if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.apply_masks(soft_inpainting=p.soft_inpainting, - nmask=p.nmask, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - # Generate the mask(s) based on similarity between the original and denoised latent vectors - if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.apply_adaptive_masks(latent_orig=p.init_latent, - latent_processed=samples_ddim, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -955,9 +937,18 @@ def infotext(index=0, use_main_prompt=False): pp = scripts.PostprocessImageArgs(image) p.scripts.postprocess_image(p, pp) image = pp.image + + mask_for_overlay = p.mask_for_overlay + overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None + + if p.scripts is not None: + ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) + p.scripts.postprocess_maskoverlay(p, ppmo) + mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image + if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: - image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) + image_without_cc = apply_overlay(image, p.paste_to, overlay_image) images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) @@ -968,9 +959,9 @@ def infotext(index=0, use_main_prompt=False): original_denoised_image = image.copy() if p.paste_to is not None: - original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to) + original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to) - image = apply_overlay(image, p.paste_to, i, p.overlay_images) + image = apply_overlay(image, p.paste_to, overlay_image) if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) @@ -981,13 +972,6 @@ def infotext(index=0, use_main_prompt=False): image.info["parameters"] = text output_images.append(image) - if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: - mask_for_overlay = p.mask_for_overlay - elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]: - mask_for_overlay = p.masks_for_overlay[i] - else: - mask_for_overlay = None - if mask_for_overlay is not None: if opts.return_mask or opts.save_mask: image_mask = mask_for_overlay.convert('RGB') @@ -1401,7 +1385,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None - soft_inpainting: si.SoftInpaintingParameters = si.default + mask_round: bool = True inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 @@ -1447,7 +1431,7 @@ def 
init(self, all_prompts, all_seeds, all_subseeds): if image_mask is not None: # image_mask is passed in as RGBA by Gradio to support alpha masks, # but we still want to support binary masks. - image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None)) + image_mask = create_binary_mask(image_mask, round=self.mask_round) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) @@ -1465,7 +1449,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - self.mask_for_overlay = image_mask if self.soft_inpainting is None else None + self.mask_for_overlay = image_mask mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1476,13 +1460,10 @@ def init(self, all_prompts, all_seeds, all_subseeds): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) + np_mask = np.array(image_mask) + np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) - if self.soft_inpainting is None: - np_mask = np.array(image_mask) - np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) - - self.masks_for_overlay = [] if self.soft_inpainting is not None else None self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1504,15 +1485,10 @@ def init(self, all_prompts, all_seeds, all_subseeds): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - if self.soft_inpainting is not None: - # We apply the masks AFTER to adjust mask based on changed content. 
- self.overlay_images.append(image.convert('RGBA')) - self.masks_for_overlay.append(image_mask) - else: - image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked = Image.new('RGBa', (image.width, image.height)) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) - self.overlay_images.append(image_masked.convert('RGBA')) + self.overlay_images.append(image_masked.convert('RGBA')) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1565,7 +1541,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] - if self.soft_inpainting is None: + if self.mask_round: latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) @@ -1578,7 +1554,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() @@ -1589,8 +1565,14 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - if self.mask is not None and self.soft_inpainting is None: - samples = samples * self.nmask + self.init_latent * self.mask + blended_samples = samples * self.nmask + self.init_latent * self.mask + + if self.scripts is not None: + mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent + + samples = blended_samples del x devices.torch_gc() diff --git a/modules/scripts.py b/modules/scripts.py index 7f9454eb578..92a07c564e9 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -11,11 +11,31 @@ AlwaysVisible = object() +class MaskBlendArgs: + def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None): + self.current_latent = current_latent + self.nmask = nmask + self.init_latent = init_latent + self.mask = mask + self.blended_samples = blended_samples + + self.denoiser = denoiser + self.is_final_blend = denoiser is None + self.sigma = sigma + +class PostSampleArgs: + def __init__(self, samples): + self.samples = samples class PostprocessImageArgs: def __init__(self, image): self.image = image +class PostProcessMaskOverlayArgs: + def __init__(self, index, mask_for_overlay, overlay_image): + self.index = index + self.mask_for_overlay = mask_for_overlay + self.overlay_image = overlay_image class PostprocessBatchListArgs: def __init__(self, images): @@ -206,6 +226,25 @@ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwarg pass + def on_mask_blend(self, p, mba: MaskBlendArgs, *args): + """ + Called in inpainting mode when the original 
content is blended with the inpainted content. + This is called at every step in the denoising process and once at the end. + If is_final_blend is true, this is called for the final blending stage. + Otherwise, denoiser and sigma are defined and may be used to inform the procedure. + """ + + pass + + def post_sample(self, p, ps: PostSampleArgs, *args): + """ + Called after the samples have been generated, + but before they have been decoded by the VAE, if applicable. + Check getattr(samples, 'already_decoded', False) to test if the images are decoded. + """ + + pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): """ Called for every image after it has been generated. @@ -213,6 +252,13 @@ def postprocess_image(self, p, pp: PostprocessImageArgs, *args): pass + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args): + """ + Called for every image after it has been generated, before the mask and overlay image are composited; both can be modified or replaced here. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -767,6 +813,22 @@ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs): except Exception: errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) + def post_sample(self, p, ps: PostSampleArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.post_sample(p, ps, *script_args) + except Exception: + errors.report(f"Error running post_sample: {script.filename}", exc_info=True) + + def on_mask_blend(self, p, mba: MaskBlendArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.on_mask_blend(p, mba, *script_args) + except Exception: + errors.report(f"Error running on_mask_blend: {script.filename}", exc_info=True) + def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: try: @@ -775,6 +837,14 @@ def postprocess_image(self, p, pp: PostprocessImageArgs): except Exception: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_maskoverlay(p, ppmo, *script_args) + except Exception: + errors.report(f"Error running postprocess_maskoverlay: {script.filename}", exc_info=True) + def before_component(self, component, **kwargs): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): try: diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index f13e8dcc5b0..eb9d5dafacc 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -109,19 +109,16 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" # If we use masks, blending between the denoised and original latent images occurs here. - def apply_blend(latent): - if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function): - return self.p.denoiser_masked_blend_function( - self, - # Using an argument dictionary so that arguments can be added without breaking extensions.
- args= - { - "denoiser": self, - "current_latent": latent, - "sigma": sigma - }) - else: - return self.init_latent * self.mask + self.nmask * latent + def apply_blend(current_latent): + blended_latent = current_latent * self.nmask + self.init_latent * self.mask + + if self.p.scripts is not None: + from modules import scripts + mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma) + self.p.scripts.on_mask_blend(self.p, mba) + blended_latent = mba.blended_latent + + return blended_latent # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: From 2abc417834d752e43a283f8603bfddfb1c80b30f Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 22:25:53 -0700 Subject: [PATCH 0401/1174] Re-implemented soft inpainting via a script. Also fixed some mistakes with the previous hooks, removed unnecessary formatting changes, removed code that I had forgotten to. --- modules/processing.py | 23 +-- modules/scripts.py | 4 +- modules/soft_inpainting.py | 308 ---------------------------- scripts/soft_inpainting.py | 401 +++++++++++++++++++++++++++++++++++++ 4 files changed, 413 insertions(+), 323 deletions(-) delete mode 100644 modules/soft_inpainting.py create mode 100644 scripts/soft_inpainting.py diff --git a/modules/processing.py b/modules/processing.py index 5a1a90afedd..f8d85bdf527 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -879,14 +879,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: ps = scripts.PostSampleArgs(samples_ddim) p.scripts.post_sample(p, ps) - samples_ddim = pp.samples + samples_ddim = ps.samples if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -944,7 +943,7 @@ def infotext(index=0, use_main_prompt=False): if p.scripts is not None: ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) p.scripts.postprocess_maskoverlay(p, ppmo) - mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image + mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: @@ -959,7 +958,7 @@ def infotext(index=0, use_main_prompt=False): original_denoised_image = image.copy() if p.paste_to is not None: - original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to) + original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to) image = apply_overlay(image, p.paste_to, overlay_image) @@ -1512,9 +1511,6 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.overlay_images is not None: self.overlay_images = self.overlay_images * self.batch_size - if self.masks_for_overlay is not None: - self.masks_for_overlay = self.masks_for_overlay * self.batch_size - if self.color_corrections is not None and len(self.color_corrections) == 1: self.color_corrections = self.color_corrections * self.batch_size @@ -1565,14 +1561,15 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs samples = 
self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - blended_samples = samples * self.nmask + self.init_latent * self.mask + if self.mask is not None: + blended_samples = samples * self.nmask + self.init_latent * self.mask - if self.scripts is not None: - mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True) - self.scripts.on_mask_blend(self, mba) - blended_samples = mba.blended_latent + if self.scripts is not None: + mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent - samples = blended_samples + samples = blended_samples del x devices.torch_gc() diff --git a/modules/scripts.py b/modules/scripts.py index 92a07c564e9..b6fcf96e047 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -12,12 +12,12 @@ AlwaysVisible = object() class MaskBlendArgs: - def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None): + def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None): self.current_latent = current_latent self.nmask = nmask self.init_latent = init_latent self.mask = mask - self.blended_samples = blended_samples + self.blended_latent = blended_latent self.denoiser = denoiser self.is_final_blend = denoiser is None diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py deleted file mode 100644 index b36ac8fa1b7..00000000000 --- a/modules/soft_inpainting.py +++ /dev/null @@ -1,308 +0,0 @@ -class SoftInpaintingSettings: - def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): - self.mask_blend_power = mask_blend_power - self.mask_blend_scale = mask_blend_scale - self.inpaint_detail_preservation = inpaint_detail_preservation - - def add_generation_params(self, dest): - dest[enabled_gen_param_label] = True - dest[gen_param_labels.mask_blend_power] = self.mask_blend_power - dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale - dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation - - -# ------------------- Methods ------------------- - - -def latent_blend(soft_inpainting, a, b, t): - """ - Interpolates two latent image representations according to the parameter t, - where the interpolated vectors' magnitudes are also interpolated separately. - The "detail_preservation" factor biases the magnitude interpolation towards - the larger of the two magnitudes. - """ - import torch - - # NOTE: We use inplace operations wherever possible. - - # [4][w][h] to [1][4][w][h] - t2 = t.unsqueeze(0) - # [4][w][h] to [1][1][w][h] - the [4] seem redundant. - t3 = t[0].unsqueeze(0).unsqueeze(0) - - one_minus_t2 = 1 - t2 - one_minus_t3 = 1 - t3 - - # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t2 - b_scaled = b * t2 - image_interp = a_scaled - image_interp.add_(b_scaled) - result_type = image_interp.dtype - del a_scaled, b_scaled, t2, one_minus_t2 - - # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) - # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) - - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). 
- a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3 - b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3 - desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) - del a_magnitude, b_magnitude, t3, one_minus_t3 - - # Change the linearly interpolated image vectors' magnitudes to the value we want. - # This is the last 64-bit operation. - image_interp_scaling_factor = desired_magnitude - image_interp_scaling_factor.div_(current_magnitude) - image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) - image_interp_scaled = image_interp - image_interp_scaled.mul_(image_interp_scaling_factor) - del current_magnitude - del desired_magnitude - del image_interp - del image_interp_scaling_factor - del result_type - - return image_interp_scaled - - -def get_modified_nmask(soft_inpainting, nmask, sigma): - """ - Converts a negative mask representing the transparency of the original latent vectors being overlayed - to a mask that is scaled according to the denoising strength for this step. - - Where: - 0 = fully opaque, infinite density, fully masked - 1 = fully transparent, zero density, fully unmasked - - We bring this transparency to a power, as this allows one to simulate N number of blending operations - where N can be any positive real value. Using this one can control the balance of influence between - the denoiser and the original latents according to the sigma value. - - NOTE: "mask" is not used - """ - import torch - # todo: Why is sigma 2D? Both values are the same. - return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) - - -def apply_adaptive_masks( - latent_orig, - latent_processed, - overlay_images, - masks_for_overlay, - width, height, - paste_to): - import torch - import numpy as np - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. - # latent_mask = p.nmask[0].float().cpu() - # convert the original mask into a form we use to scale distances for thresholding - # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) - # mask_scalar = mask_scalar / (1.00001-mask_scalar) - # mask_scalar = mask_scalar.numpy() - - latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) - - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) - - for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): - converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) - - # The distance at which opacity of original decreases to 50% - # half_weighted_distance = 1 # * mask_scalar - # converted_mask = converted_mask / half_weighted_distance - - converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) - converted_mask = 1 - converted_mask - converted_mask = 255. 
* converted_mask - converted_mask = converted_mask.astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if paste_to is not None: - converted_mask = proc. uncrop(converted_mask, - (overlay_image.width, overlay_image.height), - paste_to) - - masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - -def apply_masks( - soft_inpainting, - nmask, - overlay_images, - masks_for_overlay, - width, height, - paste_to): - import torch - import numpy as np - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - converted_mask = nmask[0].float() - converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) - converted_mask = 255. * converted_mask - converted_mask = converted_mask.cpu().numpy().astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if paste_to is not None: - converted_mask = proc.uncrop(converted_mask, - (width, height), - paste_to) - - for i, overlay_image in enumerate(overlay_images): - masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - - -# ------------------- Constants ------------------- - - -default = SoftInpaintingSettings(1, 0.5, 4) - -enabled_ui_label = "Soft inpainting" -enabled_gen_param_label = "Soft inpainting enabled" -enabled_el_id = "soft_inpainting_enabled" - -ui_labels = SoftInpaintingSettings( - "Schedule bias", - "Preservation strength", - "Transition contrast boost") - -ui_info = SoftInpaintingSettings( - "Shifts when preservation of original content occurs during denoising.", - "How strongly partially masked content should be preserved.", - "Amplifies the contrast that may be lost in partially masked regions.") - -gen_param_labels = SoftInpaintingSettings( - "Soft inpainting schedule bias", - "Soft inpainting preservation strength", - "Soft inpainting transition contrast boost") - -el_ids = SoftInpaintingSettings( - "mask_blend_power", - "mask_blend_scale", - "inpaint_detail_preservation") - - -# ------------------- UI ------------------- - - -def gradio_ui(): - import gradio as gr - from modules.ui_components import InputAccordion - - with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: - with gr.Group(): - gr.Markdown( - """ - Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. 
- **High _Mask blur_** values are recommended! - """) - - result = SoftInpaintingSettings( - gr.Slider(label=ui_labels.mask_blend_power, - info=ui_info.mask_blend_power, - minimum=0, - maximum=8, - step=0.1, - value=default.mask_blend_power, - elem_id=el_ids.mask_blend_power), - gr.Slider(label=ui_labels.mask_blend_scale, - info=ui_info.mask_blend_scale, - minimum=0, - maximum=8, - step=0.05, - value=default.mask_blend_scale, - elem_id=el_ids.mask_blend_scale), - gr.Slider(label=ui_labels.inpaint_detail_preservation, - info=ui_info.inpaint_detail_preservation, - minimum=1, - maximum=32, - step=0.5, - value=default.inpaint_detail_preservation, - elem_id=el_ids.inpaint_detail_preservation)) - - with gr.Accordion("Help", open=False): - gr.Markdown( - f""" - ### {ui_labels.mask_blend_power} - - The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). - This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. - This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. - - - **Below 1**: Stronger preservation near the end (with low sigma) - - **1**: Balanced (proportional to sigma) - - **Above 1**: Stronger preservation in the beginning (with high sigma) - """) - gr.Markdown( - f""" - ### {ui_labels.mask_blend_scale} - - Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. - This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. - - - **Low values**: Favors generated content. - - **High values**: Favors original content. - """) - gr.Markdown( - f""" - ### {ui_labels.inpaint_detail_preservation} - - This parameter controls how the original latent vectors and denoised latent vectors are interpolated. - With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. - This can prevent the loss of contrast that occurs with linear interpolation. - - - **Low values**: Softer blending, details may fade. - - **High values**: Stronger contrast, may over-saturate colors. 
- """) - - return ( - [ - soft_inpainting_enabled, - result.mask_blend_power, - result.mask_blend_scale, - result.inpaint_detail_preservation - ], - [ - (soft_inpainting_enabled, enabled_gen_param_label), - (result.mask_blend_power, gen_param_labels.mask_blend_power), - (result.mask_blend_scale, gen_param_labels.mask_blend_scale), - (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation) - ] - ) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py new file mode 100644 index 00000000000..47e0269bf1b --- /dev/null +++ b/scripts/soft_inpainting.py @@ -0,0 +1,401 @@ +import gradio as gr +from modules.ui_components import InputAccordion +import modules.scripts as scripts + + +class SoftInpaintingSettings: + def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + + +# ------------------- Methods ------------------- + + +def latent_blend(soft_inpainting, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t2 + b_scaled = b * t2 + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled, t2, one_minus_t2 + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + soft_inpainting.inpaint_detail_preservation) * t3 + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) + del a_magnitude, b_magnitude, t3, one_minus_t3 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. 
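+    # In effect this computes image_interp * (desired_magnitude / current_magnitude)
+    # in place: each blended vector keeps its interpolated direction, while its
+    # length is replaced by the detail-preserving magnitude computed above.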
+ image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(soft_inpainting, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + + +def apply_adaptive_masks( + latent_orig, + latent_processed, + overlay_images, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. + # latent_mask = p.nmask[0].float().cpu() + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + masks_for_overlay = [] + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. 
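+        # paste_to is only set when a cropped sub-region was processed (e.g. the
+        # "only masked" inpaint area option); uncrop pastes the mask back into its
+        # place on a full-size canvas so it lines up with the final output image.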
+        if paste_to is not None:
+            converted_mask = proc.uncrop(converted_mask,
+                                         (overlay_image.width, overlay_image.height),
+                                         paste_to)
+
+        masks_for_overlay.append(converted_mask)
+
+        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+                           mask=ImageOps.invert(converted_mask.convert('L')))
+
+        overlay_images[i] = image_masked.convert('RGBA')
+
+    return masks_for_overlay
+
+
+def apply_masks(
+        soft_inpainting,
+        nmask,
+        overlay_images,
+        width, height,
+        paste_to):
+    import torch
+    import numpy as np
+    import modules.processing as proc
+    import modules.images as images
+    from PIL import Image, ImageOps, ImageFilter
+
+    converted_mask = nmask[0].float()
+    converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2)
+    converted_mask = 255. * converted_mask
+    converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
+    converted_mask = Image.fromarray(converted_mask)
+    converted_mask = images.resize_image(2, converted_mask, width, height)
+    converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+    # Remove aliasing artifacts using a gaussian blur.
+    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+    # Expand the mask to fit the whole image if needed.
+    if paste_to is not None:
+        converted_mask = proc.uncrop(converted_mask,
+                                     (width, height),
+                                     paste_to)
+
+    masks_for_overlay = []
+
+    for i, overlay_image in enumerate(overlay_images):
+        masks_for_overlay.append(converted_mask)
+
+        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+                           mask=ImageOps.invert(converted_mask.convert('L')))
+
+        overlay_images[i] = image_masked.convert('RGBA')
+
+    return masks_for_overlay
+
+
+# ------------------- Constants -------------------
+
+
+default = SoftInpaintingSettings(1, 0.5, 4)
+
+enabled_ui_label = "Soft inpainting"
+enabled_gen_param_label = "Soft inpainting enabled"
+enabled_el_id = "soft_inpainting_enabled"
+
+ui_labels = SoftInpaintingSettings(
+    "Schedule bias",
+    "Preservation strength",
+    "Transition contrast boost")
+
+ui_info = SoftInpaintingSettings(
+    "Shifts when preservation of original content occurs during denoising.",
+    "How strongly partially masked content should be preserved.",
+    "Amplifies the contrast that may be lost in partially masked regions.")
+
+gen_param_labels = SoftInpaintingSettings(
+    "Soft inpainting schedule bias",
+    "Soft inpainting preservation strength",
+    "Soft inpainting transition contrast boost")
+
+el_ids = SoftInpaintingSettings(
+    "mask_blend_power",
+    "mask_blend_scale",
+    "inpaint_detail_preservation")
+
+
+class Script(scripts.Script):
+
+    def __init__(self):
+        self.masks_for_overlay = None
+        self.overlay_images = None
+
+    def title(self):
+        return "Soft Inpainting"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible if is_img2img else False
+
+    def ui(self, is_img2img):
+        if not is_img2img:
+            return
+
+        with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
+            with gr.Group():
+                gr.Markdown(
+                    """
+                    Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
+                    **High _Mask blur_** values are recommended!
+ """) + + result = SoftInpaintingSettings( + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power), + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale), + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation)) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. + + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), + (result.mask_blend_power, gen_param_labels.mask_blend_power), + (result.mask_blend_scale, gen_param_labels.mask_blend_scale), + (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)] + + self.paste_field_names = [] + for _, field_name in self.infotext_fields: + self.paste_field_names.append(field_name) + + return [soft_inpainting_enabled, + result.mask_blend_power, + result.mask_blend_scale, + result.inpaint_detail_preservation] + + def process(self, p, enabled, power, scale, detail_preservation): + if not enabled: + return + + # Shut off the rounding it normally does. + p.mask_round = False + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + # p.extra_generation_params["Mask rounding"] = False + settings.add_generation_params(p.extra_generation_params) + + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + if mba.sigma is None: + mba.blended_latent = mba.current_latent + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + # todo: Why is sigma 2D? Both values are the same. 
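+        # sigma holds one value per batch item; as the todo above notes, they are
+        # identical at a given step, so sigma[0] serves as the scalar noise level.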
+ mba.blended_latent = latent_blend(settings, + mba.init_latent, + mba.current_latent, + get_modified_nmask(settings, mba.nmask, mba.sigma[0])) + + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + from modules import images + from modules.shared import opts + + # since the original code puts holes in the existing overlay images, + # we have to rebuild them. + self.overlay_images = [] + for img in p.init_images: + + image = images.flatten(img, opts.img2img_background_color) + + if p.paste_to is None and p.resize_mode != 3: + image = images.resize_image(p.resize_mode, image, p.width, p.height) + + self.overlay_images.append(image.convert('RGBA')) + + if getattr(ps.samples, 'already_decoded', False): + self.masks_for_overlay = apply_masks(soft_inpainting=settings, + nmask=p.nmask, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + else: + self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent, + latent_processed=ps.samples, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + + + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] + ppmo.overlay_image = self.overlay_images[ppmo.index] \ No newline at end of file From 8dbacc7d018774a3bc801cc57617795274a15087 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:30:30 -0700 Subject: [PATCH 0402/1174] Fixed "No newline at end of file". --- scripts/soft_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 47e0269bf1b..6d0cf847952 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -398,4 +398,4 @@ def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, e return ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] - ppmo.overlay_image = self.overlay_images[ppmo.index] \ No newline at end of file + ppmo.overlay_image = self.overlay_images[ppmo.index] From 56604f08a18588e8e6b57d7c3f9c61d6624846f8 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:53:44 -0700 Subject: [PATCH 0403/1174] Moved image filters used by soft inpainting into soft_inpainting.py from images.py --- modules/images.py | 190 ---------------------------------- scripts/soft_inpainting.py | 205 +++++++++++++++++++++++++++++++++++-- 2 files changed, 199 insertions(+), 196 deletions(-) diff --git a/modules/images.py b/modules/images.py index 9495349866b..16f9ae7cce1 100644 --- a/modules/images.py +++ b/modules/images.py @@ -792,193 +792,3 @@ def flatten(img, bgcolor): return img.convert('RGB') - -def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): - """ - Generalization convolution filter capable of applying - weighted mean, median, maximum, and minimum filters - parametrically using an arbitrary kernel. - - Args: - img (nparray): - The image, a 2-D array of floats, to which the filter is being applied. - kernel (nparray): - The kernel, a 2-D array of floats. - kernel_center (nparray): - The kernel center coordinate, a 1-D array with two elements. 
- percentile_min (float): - The lower bound of the histogram window used by the filter, - from 0 to 1. - percentile_max (float): - The upper bound of the histogram window used by the filter, - from 0 to 1. - min_width (float): - The minimum size of the histogram window bounds, in weight units. - Must be greater than 0. - - Returns: - (nparray): A filtered copy of the input image "img", a 2-D array of floats. - """ - - # Converts an index tuple into a vector. - def vec(x): - return np.array(x) - - kernel_min = -kernel_center - kernel_max = vec(kernel.shape) - kernel_center - - def weighted_histogram_filter_single(idx): - idx = vec(idx) - min_index = np.maximum(0, idx + kernel_min) - max_index = np.minimum(vec(img.shape), idx + kernel_max) - window_shape = max_index - min_index - - class WeightedElement: - """ - An element of the histogram, its weight - and bounds. - """ - def __init__(self, value, weight): - self.value: float = value - self.weight: float = weight - self.window_min: float = 0.0 - self.window_max: float = 1.0 - - # Collect the values in the image as WeightedElements, - # weighted by their corresponding kernel values. - values = [] - for window_tup in np.ndindex(tuple(window_shape)): - window_index = vec(window_tup) - image_index = window_index + min_index - centered_kernel_index = image_index - idx - kernel_index = centered_kernel_index + kernel_center - element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) - values.append(element) - - def sort_key(x: WeightedElement): - return x.value - - values.sort(key=sort_key) - - # Calculate the height of the stack (sum) - # and each sample's range they occupy in the stack - sum = 0 - for i in range(len(values)): - values[i].window_min = sum - sum += values[i].weight - values[i].window_max = sum - - # Calculate what range of this stack ("window") - # we want to get the weighted average across. - window_min = sum * percentile_min - window_max = sum * percentile_max - window_width = window_max - window_min - - # Ensure the window is within the stack and at least a certain size. - if window_width < min_width: - window_center = (window_min + window_max) / 2 - window_min = window_center - min_width / 2 - window_max = window_center + min_width / 2 - - if window_max > sum: - window_max = sum - window_min = sum - min_width - - if window_min < 0: - window_min = 0 - window_max = min_width - - value = 0 - value_weight = 0 - - # Get the weighted average of all the samples - # that overlap with the window, weighted - # by the size of their overlap. - for i in range(len(values)): - if window_min >= values[i].window_max: - continue - if window_max <= values[i].window_min: - break - - s = max(window_min, values[i].window_min) - e = min(window_max, values[i].window_max) - w = e - s - - value += values[i].value * w - value_weight += w - - return value / value_weight if value_weight != 0 else 0 - - img_out = img.copy() - - # Apply the kernel operation over each pixel. - for index in np.ndindex(img.shape): - img_out[index] = weighted_histogram_filter_single(index) - - return img_out - -def smoothstep(x): - """ - The smoothstep function, input should be clamped to 0-1 range. - Turns a diagonal line (f(x) = x) into a sigmoid-like curve. - """ - return x * x * (3 - 2 * x) - -def smootherstep(x): - """ - The smootherstep function, input should be clamped to 0-1 range. - Turns a diagonal line (f(x) = x) into a sigmoid-like curve. 
- """ - return x * x * x * (x * (6 * x - 15) + 10) - - -def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): - """ - Creates a Gaussian kernel with thresholded edges. - - Args: - stddev_radius (float): - Standard deviation of the gaussian kernel, in pixels. - max_radius (int): - The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. - The kernel is thresholded so that any values one pixel beyond this radius - is weighted at 0. - - Returns: - (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) - """ - # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. - def gaussian(sqr_mag): - return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) - - # Helper function for converting a tuple to an array. - def vec(x): - return np.array(x) - - """ - Since a gaussian is unbounded, we need to limit ourselves - to a finite range. - We taper the ends off at the end of that range so they equal zero - while preserving the maximum value of 1 at the mean. - """ - zero_radius = max_radius + 1.0 - gauss_zero = gaussian(zero_radius * zero_radius) - gauss_kernel_scale = 1 / (1 - gauss_zero) - - def gaussian_kernel_func(coordinate): - x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 - x = gaussian(x) - x -= gauss_zero - x *= gauss_kernel_scale - x = max(0.0, x) - return x - - size = max_radius * 2 + 1 - kernel_center = max_radius - kernel = np.zeros((size, size)) - - for index in np.ndindex(kernel.shape): - kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) - - return kernel, kernel_center - diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 6d0cf847952..1f451b55370 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -1,4 +1,6 @@ +import numpy as np import gradio as gr +import math from modules.ui_components import InputAccordion import modules.scripts as scripts @@ -101,7 +103,6 @@ def apply_adaptive_masks( width, height, paste_to): import torch - import numpy as np import modules.processing as proc import modules.images as images from PIL import Image, ImageOps, ImageFilter @@ -115,15 +116,15 @@ def apply_adaptive_masks( latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2) masks_for_overlay = [] for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% @@ -131,7 +132,7 @@ def apply_adaptive_masks( # converted_mask = converted_mask / half_weighted_distance converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) + converted_mask = smootherstep(converted_mask) converted_mask = 1 - converted_mask converted_mask = 255. 
* converted_mask converted_mask = converted_mask.astype(np.uint8) @@ -166,7 +167,6 @@ def apply_masks( width, height, paste_to): import torch - import numpy as np import modules.processing as proc import modules.images as images from PIL import Image, ImageOps, ImageFilter @@ -202,6 +202,196 @@ def apply_masks( return masks_for_overlay +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + """ + Generalization convolution filter capable of applying + weighted mean, median, maximum, and minimum filters + parametrically using an arbitrary kernel. + + Args: + img (nparray): + The image, a 2-D array of floats, to which the filter is being applied. + kernel (nparray): + The kernel, a 2-D array of floats. + kernel_center (nparray): + The kernel center coordinate, a 1-D array with two elements. + percentile_min (float): + The lower bound of the histogram window used by the filter, + from 0 to 1. + percentile_max (float): + The upper bound of the histogram window used by the filter, + from 0 to 1. + min_width (float): + The minimum size of the histogram window bounds, in weight units. + Must be greater than 0. + + Returns: + (nparray): A filtered copy of the input image "img", a 2-D array of floats. + """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. 
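+        # For intuition: percentile_min=0 with percentile_max=1 spans the whole stack
+        # and yields a weighted mean; a window pinched around 0.5 (widened to
+        # min_width) behaves like a weighted median; windows pushed toward 0 or 1
+        # approach the weighted minimum or maximum, as the docstring describes.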
+ for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. + """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x *= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + + # ------------------- Constants ------------------- @@ -232,6 +422,9 @@ def apply_masks( "inpaint_detail_preservation") +# ----- + + class Script(scripts.Script): def __init__(self): From 0ef4a4cb2365051b1e308f0136a0d8c01d071569 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:54:26 -0700 Subject: [PATCH 0404/1174] Fixed error that occurs when using vanilla samplers (somehow). 
--- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f8d85bdf527..bea01ec6884 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -937,8 +937,8 @@ def infotext(index=0, use_main_prompt=False): p.scripts.postprocess_image(p, pp) image = pp.image - mask_for_overlay = p.mask_for_overlay - overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None + mask_for_overlay = getattr(p, "mask_for_overlay", None) + overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None if p.scripts is not None: ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) From f284ae23bcdfa212cf4763659c06e124ec5b1456 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 20:19:35 -0700 Subject: [PATCH 0405/1174] Added parameters for the composite stage, fixed batched generation. --- scripts/soft_inpainting.py | 198 +++++++++++++++++++++++++++++-------- 1 file changed, 155 insertions(+), 43 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 1f451b55370..1b21aee9db9 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -6,22 +6,34 @@ class SoftInpaintingSettings: - def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + def __init__(self, + mask_blend_power, + mask_blend_scale, + inpaint_detail_preservation, + composite_mask_influence, + composite_difference_threshold, + composite_difference_contrast): self.mask_blend_power = mask_blend_power self.mask_blend_scale = mask_blend_scale self.inpaint_detail_preservation = inpaint_detail_preservation + self.composite_mask_influence = composite_mask_influence + self.composite_difference_threshold = composite_difference_threshold + self.composite_difference_contrast = composite_difference_contrast def add_generation_params(self, dest): dest[enabled_gen_param_label] = True dest[gen_param_labels.mask_blend_power] = self.mask_blend_power dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence + dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold + dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast # ------------------- Methods ------------------- -def latent_blend(soft_inpainting, a, b, t): +def latent_blend(settings, a, b, t): """ Interpolates two latent image representations according to the parameter t, where the interpolated vectors' magnitudes are also interpolated separately. @@ -54,11 +66,11 @@ def latent_blend(soft_inpainting, a, b, t): # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). 
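     # With inpaint_detail_preservation = n this is a weighted power mean,
     # ((1-t)*|a|**n + t*|b|**n)**(1/n), which leans toward max(|a|, |b|) as n grows:
     # e.g. for |a|=1, |b|=2, t=0.5, n=4 it gives 8.5**0.25 ~= 1.71 versus 1.5 linearly.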
a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + settings.inpaint_detail_preservation) * one_minus_t3 b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - soft_inpainting.inpaint_detail_preservation) * t3 + settings.inpaint_detail_preservation) * t3 desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) + desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation) del a_magnitude, b_magnitude, t3, one_minus_t3 # Change the linearly interpolated image vectors' magnitudes to the value we want. @@ -77,7 +89,7 @@ def latent_blend(soft_inpainting, a, b, t): return image_interp_scaled -def get_modified_nmask(soft_inpainting, nmask, sigma): +def get_modified_nmask(settings, nmask, sigma): """ Converts a negative mask representing the transparency of the original latent vectors being overlayed to a mask that is scaled according to the denoising strength for this step. @@ -93,10 +105,12 @@ def get_modified_nmask(soft_inpainting, nmask, sigma): NOTE: "mask" is not used """ import torch - return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale) def apply_adaptive_masks( + settings:SoftInpaintingSettings, + nmask, latent_orig, latent_processed, overlay_images, @@ -108,11 +122,13 @@ def apply_adaptive_masks( from PIL import Image, ImageOps, ImageFilter # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. - # latent_mask = p.nmask[0].float().cpu() + latent_mask = nmask[0].float() # convert the original mask into a form we use to scale distances for thresholding - # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) - # mask_scalar = mask_scalar / (1.00001-mask_scalar) - # mask_scalar = mask_scalar.numpy() + mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1-settings.composite_mask_influence) + + mask_scalar * settings.composite_mask_influence) + mask_scalar = mask_scalar / (1.00001-mask_scalar) + mask_scalar = mask_scalar.cpu().numpy() latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) @@ -128,10 +144,10 @@ def apply_adaptive_masks( percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% - # half_weighted_distance = 1 # * mask_scalar - # converted_mask = converted_mask / half_weighted_distance + half_weighted_distance = settings.composite_difference_threshold * mask_scalar + converted_mask = converted_mask / half_weighted_distance - converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast) converted_mask = smootherstep(converted_mask) converted_mask = 1 - converted_mask converted_mask = 255. 
* converted_mask @@ -161,7 +177,7 @@ def apply_adaptive_masks( def apply_masks( - soft_inpainting, + settings, nmask, overlay_images, width, height, @@ -172,7 +188,7 @@ def apply_masks( from PIL import Image, ImageOps, ImageFilter converted_mask = nmask[0].float() - converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2) converted_mask = 255. * converted_mask converted_mask = converted_mask.cpu().numpy().astype(np.uint8) converted_mask = Image.fromarray(converted_mask) @@ -395,7 +411,7 @@ def gaussian_kernel_func(coordinate): # ------------------- Constants ------------------- -default = SoftInpaintingSettings(1, 0.5, 4) +default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2) enabled_ui_label = "Soft inpainting" enabled_gen_param_label = "Soft inpainting enabled" @@ -404,25 +420,37 @@ def gaussian_kernel_func(coordinate): ui_labels = SoftInpaintingSettings( "Schedule bias", "Preservation strength", - "Transition contrast boost") + "Transition contrast boost", + "Mask influence", + "Difference threshold", + "Difference contrast") ui_info = SoftInpaintingSettings( "Shifts when preservation of original content occurs during denoising.", "How strongly partially masked content should be preserved.", - "Amplifies the contrast that may be lost in partially masked regions.") + "Amplifies the contrast that may be lost in partially masked regions.", + "How strongly the original mask should bias the difference threshold.", + "How much an image region can change before the original pixels are not blended in anymore.", + "How sharp the transition should be between blended and not blended.") gen_param_labels = SoftInpaintingSettings( "Soft inpainting schedule bias", "Soft inpainting preservation strength", - "Soft inpainting transition contrast boost") + "Soft inpainting transition contrast boost", + "Soft inpainting mask influence", + "Soft inpainting difference threshold", + "Soft inpainting difference contrast") el_ids = SoftInpaintingSettings( "mask_blend_power", "mask_blend_scale", - "inpaint_detail_preservation") + "inpaint_detail_preservation", + "composite_mask_influence", + "composite_difference_threshold", + "composite_difference_contrast") -# ----- +# ------------------- Script ------------------- class Script(scripts.Script): @@ -449,28 +477,62 @@ def ui(self, is_img2img): **High _Mask blur_** values are recommended! 
""") - result = SoftInpaintingSettings( + power = \ gr.Slider(label=ui_labels.mask_blend_power, info=ui_info.mask_blend_power, minimum=0, maximum=8, step=0.1, value=default.mask_blend_power, - elem_id=el_ids.mask_blend_power), + elem_id=el_ids.mask_blend_power) + scale = \ gr.Slider(label=ui_labels.mask_blend_scale, info=ui_info.mask_blend_scale, minimum=0, maximum=8, step=0.05, value=default.mask_blend_scale, - elem_id=el_ids.mask_blend_scale), + elem_id=el_ids.mask_blend_scale) + detail = \ gr.Slider(label=ui_labels.inpaint_detail_preservation, info=ui_info.inpaint_detail_preservation, minimum=1, maximum=32, step=0.5, value=default.inpaint_detail_preservation, - elem_id=el_ids.inpaint_detail_preservation)) + elem_id=el_ids.inpaint_detail_preservation) + + gr.Markdown( + """ + ### Pixel Composite Settings + """) + + mask_inf = \ + gr.Slider(label=ui_labels.composite_mask_influence, + info=ui_info.composite_mask_influence, + minimum=0, + maximum=1, + step=0.05, + value=default.composite_mask_influence, + elem_id=el_ids.composite_mask_influence) + + dif_thresh = \ + gr.Slider(label=ui_labels.composite_difference_threshold, + info=ui_info.composite_difference_threshold, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_threshold, + elem_id=el_ids.composite_difference_threshold) + + dif_contr = \ + gr.Slider(label=ui_labels.composite_difference_contrast, + info=ui_info.composite_difference_contrast, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_contrast, + elem_id=el_ids.composite_difference_contrast) with gr.Accordion("Help", open=False): gr.Markdown( @@ -507,41 +569,86 @@ def ui(self, is_img2img): - **High values**: Stronger contrast, may over-saturate colors. """) + gr.Markdown( + """ + ## Pixel Composite Settings + + Masks are generated based on how much a part of the image changed after denoising. + These masks are used to blend the original and final images together. + If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_mask_influence} + + This parameter controls how much the mask should bias this sensitivity to difference. + + - **0**: Ignore the mask, only consider differences in image content. + - **1**: Follow the mask closely despite image content changes. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_threshold} + + This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_contrast} + + This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. 
+ """) + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), - (result.mask_blend_power, gen_param_labels.mask_blend_power), - (result.mask_blend_scale, gen_param_labels.mask_blend_scale), - (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)] + (power, gen_param_labels.mask_blend_power), + (scale, gen_param_labels.mask_blend_scale), + (detail, gen_param_labels.inpaint_detail_preservation), + (mask_inf, gen_param_labels.composite_mask_influence), + (dif_thresh, gen_param_labels.composite_difference_threshold), + (dif_contr, gen_param_labels.composite_difference_contrast)] self.paste_field_names = [] for _, field_name in self.infotext_fields: self.paste_field_names.append(field_name) return [soft_inpainting_enabled, - result.mask_blend_power, - result.mask_blend_scale, - result.inpaint_detail_preservation] - - def process(self, p, enabled, power, scale, detail_preservation): + power, + scale, + detail, + mask_inf, + dif_thresh, + dif_contr] + + def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return # Shut off the rounding it normally does. p.mask_round = False - settings = SoftInpaintingSettings(power, scale, detail_preservation) + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) # p.extra_generation_params["Mask rounding"] = False settings.add_generation_params(p.extra_generation_params) - def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation): + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return - if mba.sigma is None: + if mba.is_final_blend: mba.blended_latent = mba.current_latent return - settings = SoftInpaintingSettings(power, scale, detail_preservation) + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) # todo: Why is sigma 2D? Both values are the same. 
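         # get_modified_nmask raises the mask to a sigma-dependent power, so the share
         # of original latent preserved at each step tracks the denoising schedule.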
mba.blended_latent = latent_blend(settings, @@ -549,11 +656,11 @@ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, de mba.current_latent, get_modified_nmask(settings, mba.nmask, mba.sigma[0])) - def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation): + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return - settings = SoftInpaintingSettings(power, scale, detail_preservation) + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) from modules import images from modules.shared import opts @@ -570,15 +677,20 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta self.overlay_images.append(image.convert('RGBA')) + if len(p.init_images) == 1: + self.overlay_images = self.overlay_images * p.batch_size + if getattr(ps.samples, 'already_decoded', False): - self.masks_for_overlay = apply_masks(soft_inpainting=settings, + self.masks_for_overlay = apply_masks(settings=settings, nmask=p.nmask, overlay_images=self.overlay_images, width=p.width, height=p.height, paste_to=p.paste_to) else: - self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent, + self.masks_for_overlay = apply_adaptive_masks(settings=settings, + nmask=p.nmask, + latent_orig=p.init_latent, latent_processed=ps.samples, overlay_images=self.overlay_images, width=p.width, @@ -586,7 +698,7 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta paste_to=p.paste_to) - def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation): + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return From fc3e246c0f4f292c33b181a902cd934629ff0d7a Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 20:28:38 -0700 Subject: [PATCH 0406/1174] Fixed complaint about whitespace, updated help section for a parameter. --- scripts/soft_inpainting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 1b21aee9db9..6fb5cfbd074 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -572,7 +572,7 @@ def ui(self, is_img2img): gr.Markdown( """ ## Pixel Composite Settings - + Masks are generated based on how much a part of the image changed after denoising. These masks are used to blend the original and final images together. If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. @@ -602,10 +602,10 @@ def ui(self, is_img2img): f""" ### {ui_labels.composite_difference_contrast} - This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + This value represents the contrast between the opacity of the original and inpainted content. - - **Low values**: Two images patches must be almost the same in order to retain original pixels. - - **High values**: Two images patches can be very different and still retain original pixels. + - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting. + - **High values**: Ghosting will be less common, but transitions may be very sudden. 
""") self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), From 659f62e120b210e3043712ff928e8b7b6cd6cf61 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 21:39:54 -0700 Subject: [PATCH 0407/1174] Fixed grammar error. --- scripts/soft_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 6fb5cfbd074..51f9ca2fed2 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -592,7 +592,7 @@ def ui(self, is_img2img): f""" ### {ui_labels.composite_difference_threshold} - This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + This value represents the difference at which the original pixels will have less than 50% opacity. - **Low values**: Two images patches must be almost the same in order to retain original pixels. - **High values**: Two images patches can be very different and still retain original pixels. From 16bdcce92d5b482d50cdc32a8f308040d320b6c9 Mon Sep 17 00:00:00 2001 From: Rene Kroon Date: Fri, 8 Dec 2023 21:19:29 +0100 Subject: [PATCH 0408/1174] #13354: solve lora loading issue --- extensions-builtin/Lora/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 7f814706adb..629bf85376d 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -159,7 +159,8 @@ def load_network(name, network_on_disk): bundle_embeddings = {} for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + key_network_without_network_parts, _, network_part = key_network.partition(".") + if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) From b2414476ef164ba55cff2508c58b73d23bbc3000 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Fri, 8 Dec 2023 17:32:41 -0700 Subject: [PATCH 0409/1174] soft_inpainting now appears in the "inpaint" section, and will not activate unless inpainting is activated. --- scripts/soft_inpainting.py | 43 ++++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 51f9ca2fed2..f10a1e5627c 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -32,6 +32,19 @@ def add_generation_params(self, dest): # ------------------- Methods ------------------- +def processing_uses_inpainting(p): + # TODO: Figure out a better way to determine if inpainting is being used by p + if getattr(p, "image_mask", None) is not None: + return True + + if getattr(p, "mask", None) is not None: + return True + + if getattr(p, "nmask", None) is not None: + return True + + return False + def latent_blend(settings, a, b, t): """ @@ -454,8 +467,8 @@ def gaussian_kernel_func(coordinate): class Script(scripts.Script): - def __init__(self): + self.section = "inpaint" self.masks_for_overlay = None self.overlay_images = None @@ -632,6 +645,9 @@ def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_t if not enabled: return + if not processing_uses_inpainting(p): + return + # Shut off the rounding it normally does. 
p.mask_round = False @@ -644,6 +660,9 @@ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, de if not enabled: return + if not processing_uses_inpainting(p): + return + if mba.is_final_blend: mba.blended_latent = mba.current_latent return @@ -660,11 +679,18 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta if not enabled: return - settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + if not processing_uses_inpainting(p): + return + + nmask = getattr(p, "nmask", None) + if nmask is None: + return from modules import images from modules.shared import opts + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + # since the original code puts holes in the existing overlay images, # we have to rebuild them. self.overlay_images = [] @@ -682,14 +708,14 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta if getattr(ps.samples, 'already_decoded', False): self.masks_for_overlay = apply_masks(settings=settings, - nmask=p.nmask, + nmask=nmask, overlay_images=self.overlay_images, width=p.width, height=p.height, paste_to=p.paste_to) else: self.masks_for_overlay = apply_adaptive_masks(settings=settings, - nmask=p.nmask, + nmask=nmask, latent_orig=p.init_latent, latent_processed=ps.samples, overlay_images=self.overlay_images, @@ -702,5 +728,14 @@ def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, e if not enabled: return + if not processing_uses_inpainting(p): + return + + if self.masks_for_overlay is None: + return + + if self.overlay_images is None: + return + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] ppmo.overlay_image = self.overlay_images[ppmo.index] From f1ff932cafa2bf34fa35f41072f21a8ea5474d84 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Fri, 8 Dec 2023 17:33:11 -0700 Subject: [PATCH 0410/1174] Formatted soft_inpainting. --- scripts/soft_inpainting.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index f10a1e5627c..d9024344251 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -122,7 +122,7 @@ def get_modified_nmask(settings, nmask, sigma): def apply_adaptive_masks( - settings:SoftInpaintingSettings, + settings: SoftInpaintingSettings, nmask, latent_orig, latent_processed, @@ -137,10 +137,10 @@ def apply_adaptive_masks( # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. 
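        # A rough reading of the reformatted lines below: composite_mask_influence
        # linearly mixes a constant 0.5 with the mask-derived scalar (0 = ignore the
        # mask, 1 = fully mask-driven), and the final x / (1.00001 - x) step remaps
        # the [0, 1) scalar onto [0, inf) odds-style, the small epsilon guarding
        # against division by zero.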
latent_mask = nmask[0].float() # convert the original mask into a form we use to scale distances for thresholding - mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) - mask_scalar = (0.5 * (1-settings.composite_mask_influence) + mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1 - settings.composite_mask_influence) + mask_scalar * settings.composite_mask_influence) - mask_scalar = mask_scalar / (1.00001-mask_scalar) + mask_scalar = mask_scalar / (1.00001 - mask_scalar) mask_scalar = mask_scalar.cpu().numpy() latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) @@ -152,9 +152,9 @@ def apply_adaptive_masks( for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): converted_mask = distance_map.float().cpu().numpy() converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) + percentile_min=0.9, percentile_max=1, min_width=1) converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) + percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% half_weighted_distance = settings.composite_difference_threshold * mask_scalar @@ -276,6 +276,7 @@ class WeightedElement: An element of the histogram, its weight and bounds. """ + def __init__(self, value, weight): self.value: float = value self.weight: float = weight @@ -355,6 +356,7 @@ def sort_key(x: WeightedElement): return img_out + def smoothstep(x): """ The smoothstep function, input should be clamped to 0-1 range. @@ -362,6 +364,7 @@ def smoothstep(x): """ return x * x * (3 - 2 * x) + def smootherstep(x): """ The smootherstep function, input should be clamped to 0-1 range. @@ -385,6 +388,7 @@ def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): Returns: (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. 
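    # i.e. g(d^2) = exp(-d^2 / r^2) with r = stddev_radius, so g(0) = 1;
    # "0-1 normalized" means peak value 1 rather than unit integral.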
def gaussian(sqr_mag): return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) @@ -656,7 +660,8 @@ def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_t # p.extra_generation_params["Mask rounding"] = False settings.add_generation_params(p.extra_generation_params) - def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): if not enabled: return @@ -675,7 +680,8 @@ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, de mba.current_latent, get_modified_nmask(settings, mba.nmask, mba.sigma[0])) - def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): if not enabled: return @@ -723,8 +729,8 @@ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, deta height=p.height, paste_to=p.paste_to) - - def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, + detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return From 59429793440fb3cb1624ddcc702c6f9807373203 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:09:45 +0800 Subject: [PATCH 0411/1174] Fix ControlNet --- modules/xpu_specific.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index ec1ad100aae..9bb0a56155a 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -51,3 +51,9 @@ def torch_xpu_gc(): CondFunc('torch.bmm', lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file From 049d5642e58d572ee8657ac754e72d019eea0e6c Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:11:26 +0800 Subject: [PATCH 0412/1174] Fix format --- modules/xpu_specific.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 9bb0a56155a..d8da94a0efd 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -56,4 +56,4 @@ def torch_xpu_gc(): lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), 
value.to(query.dtype), attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) From 39ec4cfea9040bc94e639eb4aa8ab8ed37a68f01 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 19:12:59 +0600 Subject: [PATCH 0413/1174] Re-add setting lost as part of e294e46 --- modules/shared_options.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018fb..acb6e2d4819 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -256,6 +256,7 @@ "keyedit_precision_extra": OptionInfo(0.05, "Precision for when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), + "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), })) From 9c201550ddae0b33367adfb99bcbb57ba9b207a9 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 21:04:45 +0600 Subject: [PATCH 0414/1174] Add keyboard shortcuts for generation (Removed Alt+Enter) Ctrl+Enter to start/restart generation (New) Alt/Option+Enter to skip generation (New) Ctrl+Alt/Option+Enter to interrupt generation --- modules/ui_toprow.py | 4 ++-- script.js | 23 +++++++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 88838f97749..c3865e3d91d 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -79,11 +79,11 @@ def create_inline_toprow_image(self): def create_prompts(self): with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) self.prompt_img.change( 
fn=modules.images.image_data, diff --git a/script.js b/script.js index c0e678ea702..69598f45454 100644 --- a/script.js +++ b/script.js @@ -121,16 +121,21 @@ document.addEventListener("DOMContentLoaded", function() { }); /** - * Add a ctrl+enter as a shortcut to start a generation + * Add keyboard shortcuts: + * Ctrl+Enter to start/restart a generation + * Alt/Option+Enter to skip a generation + * Alt/Option+Ctrl+Enter to interrupt a generation */ document.addEventListener('keydown', function(e) { const isEnter = e.key === 'Enter' || e.keyCode === 13; - const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; + const isCtrlKey = e.metaKey || e.ctrlKey; + const isAltKey = e.altKey; - const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); + const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); + const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]'); - if (isEnter && isModifierKey) { + if (isCtrlKey && isEnter && !isAltKey) { if (interruptButton.style.display === 'block') { interruptButton.click(); const callback = (mutationList) => { @@ -150,6 +155,16 @@ document.addEventListener('keydown', function(e) { } e.preventDefault(); } + + if (isAltKey && isEnter && !isCtrlKey) { + skipButton.click(); + e.preventDefault(); + } + + if (isAltKey && isCtrlKey && isEnter) { + interruptButton.click(); + e.preventDefault(); + } }); /** From 1a79a5049bdfef285235e83f37b201e39dd54f81 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 22:35:31 +0600 Subject: [PATCH 0415/1174] Assign id for "extra_options". Replace numeric field with slider in Settings. --- .../extra-options-section/scripts/extra_options_section.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a903df62548..b9867fe63b8 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -23,11 +23,12 @@ def ui(self, is_img2img): self.setting_names = [] self.infotext_fields = [] extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): + with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) @@ -70,7 +71,7 @@ def before_process(self, p, *args): """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", 
"settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 6}).needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) From 5381405eaa1e809e5cfb97522bd4c19d3c946079 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 9 Dec 2023 14:09:28 -0500 Subject: [PATCH 0416/1174] re-derive sqrt alpha bar and sqrt one minus alphabar This is the only place these values are ever referenced outside of training code so this change is very justifiable and more consistent. --- modules/sd_samplers_timesteps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index b17a8f93c2b..c4bd5c12775 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -36,7 +36,7 @@ def __init__(self, model, *args, **kwargs): self.inner_model = model def predict_eps_from_z_and_v(self, x_t, t, v): - return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t + return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t def forward(self, input, timesteps, **kwargs): model_output = self.inner_model.apply_model(input, timesteps, **kwargs) From 23a0e60b9bf90a80f8af9732cc6495fbfce2ea21 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 10 Dec 2023 14:03:41 +0900 Subject: [PATCH 0417/1174] fix save styles --- modules/styles.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 7fb6c2e1177..07588945eba 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -155,10 +155,8 @@ def load_from_csv(self, path: str): row["name"], prompt, negative_prompt, path ) - def get_style_paths(self) -> list(): - """ - Returns a list of all distinct paths, including the default path, of - files that styles are loaded from.""" + def get_style_paths(self) -> set: + """Returns a set of all distinct paths of files that styles are loaded from.""" # Update any styles without a path to the default path for style in list(self.styles.values()): if not style.path: @@ -172,9 +170,9 @@ def get_style_paths(self) -> list(): style_paths.add(style.path) # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths.discard("do_not_save") - return list(style_paths) + return style_paths def get_style_prompts(self, styles): return [self.styles.get(x, self.no_style).prompt for x in styles] @@ -196,20 +194,7 @@ def save_styles(self, path: str = None) -> None: # The path argument is deprecated, but kept for backwards compatibility _ = path - # Update any styles without a path to the default path - for style in list(self.styles.values()): - if not style.path: - self.styles[style.name] = style._replace(path=self.default_path) - - # Create a list of all distinct paths, including the default path - style_paths = set() - 
style_paths.add(self.default_path) - for _, style in self.styles.items(): - if style.path: - style_paths.add(style.path) - - # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths = self.get_style_paths() csv_names = [os.path.split(path)[1].lower() for path in style_paths] From 8b74389e76a7678e972583ef16100e90e1519e55 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 10 Dec 2023 15:48:16 +0900 Subject: [PATCH 0418/1174] fix styles.csv filename --- modules/styles.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 07588945eba..81d9800d184 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -98,10 +98,8 @@ def __init__(self, path: str): self.path = path folder, file = os.path.split(self.path) - self.default_file = file.split("*")[0] + ".csv" - if self.default_file == ".csv": - self.default_file = "styles.csv" - self.default_path = os.path.join(folder, self.default_file) + filename, _, ext = file.partition('*') + self.default_path = os.path.join(folder, filename + ext) self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] From 6b8143a84e112f029ee1868b6ab98b1d2c773ead Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 15:35:06 +0600 Subject: [PATCH 0419/1174] Number of columns slider: max count set to 20, add description info --- .../extra-options-section/scripts/extra_options_section.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index b9867fe63b8..ac2c3de4643 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -71,7 +71,7 @@ def before_process(self, p, *args): """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 6}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) From 1d42babd324b933bae317cb427fe0513138954f4 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 16:28:56 +0600 Subject: [PATCH 0420/1174] Replace Ctrl+Alt+Enter with Esc --- modules/ui_toprow.py | 4 ++-- script.js | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index c3865e3d91d..9caf8faa2de 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -79,11 +79,11 @@ def 
create_inline_toprow_image(self): def create_prompts(self): with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) self.prompt_img.change( fn=modules.images.image_data, diff --git a/script.js b/script.js index 69598f45454..44950090aba 100644 --- a/script.js +++ b/script.js @@ -124,18 +124,19 @@ document.addEventListener("DOMContentLoaded", function() { * Add keyboard shortcuts: * Ctrl+Enter to start/restart a generation * Alt/Option+Enter to skip a generation - * Alt/Option+Ctrl+Enter to interrupt a generation + * Esc to interrupt a generation */ document.addEventListener('keydown', function(e) { const isEnter = e.key === 'Enter' || e.keyCode === 13; const isCtrlKey = e.metaKey || e.ctrlKey; const isAltKey = e.altKey; + const isEsc = e.key === 'Escape'; const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]'); - if (isCtrlKey && isEnter && !isAltKey) { + if (isCtrlKey && isEnter) { if (interruptButton.style.display === 'block') { interruptButton.click(); const callback = (mutationList) => { @@ -156,14 +157,18 @@ document.addEventListener('keydown', function(e) { e.preventDefault(); } - if (isAltKey && isEnter && !isCtrlKey) { + if (isAltKey && isEnter) { skipButton.click(); e.preventDefault(); } - if (isAltKey && isCtrlKey && isEnter) { - interruptButton.click(); - e.preventDefault(); + if (isEsc) { + if (!globalPopup || globalPopup.style.display === "none") { + interruptButton.click(); + e.preventDefault(); + } else { + closePopup(); + } } }); From cee1a4065162982e18f32761259d9107538c2d93 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 17:06:12 +0600 Subject: [PATCH 0421/1174] Fix linter issues --- script.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/script.js b/script.js index 44950090aba..354154b01d6 100644 --- a/script.js +++ b/script.js @@ -163,11 +163,13 @@ document.addEventListener('keydown', function(e) { } if (isEsc) { + const globalPopup = document.querySelector('.global-popup'); if (!globalPopup || globalPopup.style.display 
=== "none") { interruptButton.click(); e.preventDefault(); } else { - closePopup(); + if (!globalPopup) return; + globalPopup.style.display = "none"; } } }); From 6513470f0db1aed1b0a5200634e8e02f7c05e932 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Mon, 11 Dec 2023 18:06:08 +0600 Subject: [PATCH 0422/1174] Remove unnecessary 'else', add 'lightboxModal' check --- script.js | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/script.js b/script.js index 354154b01d6..be1bc317e2b 100644 --- a/script.js +++ b/script.js @@ -164,12 +164,11 @@ document.addEventListener('keydown', function(e) { if (isEsc) { const globalPopup = document.querySelector('.global-popup'); - if (!globalPopup || globalPopup.style.display === "none") { + const lightboxModal = document.querySelector('#lightboxModal'); + if (!globalPopup || globalPopup.style.display === 'none') { + if (document.activeElement === lightboxModal) return; interruptButton.click(); e.preventDefault(); - } else { - if (!globalPopup) return; - globalPopup.style.display = "none"; } } }); From cc41cc4349514bbfeb9f37445c931a050b076bd6 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 13 Dec 2023 02:06:56 +0900 Subject: [PATCH 0423/1174] on mouse hover show / hide modal image viewer icons --- style.css | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/style.css b/style.css index ee39a57b73e..ec449bdea44 100644 --- a/style.css +++ b/style.css @@ -749,6 +749,22 @@ table.popup-table .link{ display: none; } +@media (pointer: fine) { + .modalPrev:hover, + .modalNext:hover, + .modalControls:hover ~ .modalPrev, + .modalControls:hover ~ .modalNext, + .modalControls:hover .cursor { + opacity: 1; + } + + .modalPrev, + .modalNext, + .modalControls .cursor { + opacity: 0; + } +} + /* context menu (ie for the generate button) */ #context-menu{ From bda86f0fd9653657c146f7c1128f92771d16ad4e Mon Sep 17 00:00:00 2001 From: Hina <102651522+HinaHyugaHime@users.noreply.github.com> Date: Tue, 12 Dec 2023 19:39:14 -0600 Subject: [PATCH 0424/1174] Update webui.sh --- webui.sh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/webui.sh b/webui.sh index 3d0f87eed74..046ecf9d001 100755 --- a/webui.sh +++ b/webui.sh @@ -131,7 +131,7 @@ case "$gpu_info" in if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]] then # Navi users will still use torch 1.13 because 2.0 does not seem to work. 
- export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2" + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" else printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m" exit 1 @@ -141,9 +141,8 @@ case "$gpu_info" in *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ - export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" - # Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5) - # so switch to nightly rocm5.6 without explicit versions this time + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7" + ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}" From 89cfbc3bbe401fe1655afb07edbae34ec6af7aca Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 13 Dec 2023 12:22:13 +0200 Subject: [PATCH 0425/1174] Allow pasting in WIDTHxHEIGHT strings into the width/height fields --- javascript/ui.js | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/javascript/ui.js b/javascript/ui.js index 410fc44e3d9..18c9f891afc 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -215,9 +215,33 @@ function restoreProgressImg2img() { } +/** + * Configure the width and height elements on `tabname` to accept + * pasting of resolutions in the form of "width x height". + */ +function setupResolutionPasting(tabname) { + var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`); + var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`); + for (const el of [width, height]) { + el.addEventListener('paste', function(event) { + var pasteData = event.clipboardData.getData('text/plain'); + var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/); + if (parsed) { + width.value = parsed[1]; + height.value = parsed[2]; + updateInput(width); + updateInput(height); + event.preventDefault(); + } + }); + } +} + onUiLoaded(function() { showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); showRestoreProgressButton('img2img', localGet("img2img_task_id")); + setupResolutionPasting('txt2img'); + setupResolutionPasting('img2img'); }); From 735c9e8059384d4f640e5582413c30871f83eac5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:38:32 +0800 Subject: [PATCH 0426/1174] Fix network_oft --- extensions-builtin/Lora/network_oft.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 05c3781187a..44465f7aa26 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -53,12 +53,17 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def calc_updown_kb(self, orig_weight, multiplier): + def calc_updown(self, orig_weight): + I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, 
dtype=orig_weight.dtype) - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -70,15 +75,10 @@ def calc_updown_kb(self, orig_weight, multiplier): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it - multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) - - # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) @@ -94,4 +94,5 @@ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if ex_bias is not None: ex_bias = ex_bias * self.multiplier() - return updown, ex_bias + # Ignore calc_scale, which is not used in OFT. + return updown * self.multiplier(), ex_bias From 265bc26c21264d63956e8f30f1ce31dec917fc76 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:43:24 +0800 Subject: [PATCH 0427/1174] Use self.scale instead of custom finalize --- extensions-builtin/Lora/network_oft.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 44465f7aa26..e3ae61a22b4 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -21,6 +21,8 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -78,21 +80,3 @@ def calc_updown(self, orig_weight): print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - # Ignore calc_scale, which is not used in OFT. 
- return updown * self.multiplier(), ex_bias From 8fc67f3851babd4575d3312b931d5e7c2b0c78c6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:44:49 +0800 Subject: [PATCH 0428/1174] remove debug print --- extensions-builtin/Lora/network_oft.py | 1 - 1 file changed, 1 deletion(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e3ae61a22b4..ff4eb59b1c2 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -77,6 +77,5 @@ def calc_updown(self, orig_weight): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) From 3772a82a70769fe1aac884a75bf5a3313fb83328 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:47:13 +0800 Subject: [PATCH 0429/1174] better naming and correct order for device. --- extensions-builtin/Lora/network_oft.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff4eb59b1c2..fa647020f0a 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,14 +56,15 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + if self.is_kohya: block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) From 6ef0ff39f2a35a02e5380e522e1dff3eafd7ccfc Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 14 Dec 2023 09:39:57 +0300 Subject: [PATCH 0430/1174] Merge pull request #14300 from AUTOMATIC1111/oft_fixes Fix wrong implementation in network_oft --- extensions-builtin/Lora/network_oft.py | 37 ++++++++------------------ 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 05c3781187a..fa647020f0a 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -21,6 +21,8 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -53,12 +55,18 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def calc_updown_kb(self, orig_weight, multiplier): + def calc_updown(self, orig_weight): oft_blocks = 
self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -72,26 +80,3 @@ def calc_updown_kb(self, orig_weight, multiplier): updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it - multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) - - # override to remove the multiplier/scale factor; it's already multiplied in get_weight - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - return updown, ex_bias From c7cd9b441d9061f33b7b88be519fb4c6e5b8bc1e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 14 Dec 2023 09:41:18 +0300 Subject: [PATCH 0431/1174] Merge pull request #14296 from akx/paste-resolution Allow pasting in WIDTHxHEIGHT strings into the width/height fields --- javascript/ui.js | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/javascript/ui.js b/javascript/ui.js index 410fc44e3d9..18c9f891afc 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -215,9 +215,33 @@ function restoreProgressImg2img() { } +/** + * Configure the width and height elements on `tabname` to accept + * pasting of resolutions in the form of "width x height". 
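+ * For example, pasting "1024x768" (or "1024 x 768") fills width = 1024 and
+ * height = 768; the regex below accepts any non-digit run as the separator.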
+ */ +function setupResolutionPasting(tabname) { + var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`); + var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`); + for (const el of [width, height]) { + el.addEventListener('paste', function(event) { + var pasteData = event.clipboardData.getData('text/plain'); + var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/); + if (parsed) { + width.value = parsed[1]; + height.value = parsed[2]; + updateInput(width); + updateInput(height); + event.preventDefault(); + } + }); + } +} + onUiLoaded(function() { showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); showRestoreProgressButton('img2img', localGet("img2img_task_id")); + setupResolutionPasting('txt2img'); + setupResolutionPasting('img2img'); }); From b55f09c4e13c082590bc64cd792a0b7bd46c1c0d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 14 Dec 2023 09:46:05 +0300 Subject: [PATCH 0432/1174] Merge pull request #14270 from kaalibro/extra-options-elem-id Assign id for "extra_options". Replace numeric field with slider. --- .../extra-options-section/scripts/extra_options_section.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a903df62548..ac2c3de4643 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -23,11 +23,12 @@ def ui(self, is_img2img): self.setting_names = [] self.infotext_fields = [] extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): + with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) @@ -70,7 +71,7 @@ def before_process(self, p, *args): """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) From 888b928f0da7da4a2dfa4519e95ac17a3e5562f7 Mon Sep 17 00:00:00 2001 From: 
AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 14 Dec 2023 09:48:14 +0300 Subject: [PATCH 0433/1174] Merge pull request #14276 from AUTOMATIC1111/fix-styles Fix styles --- modules/styles.py | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 7fb6c2e1177..81d9800d184 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -98,10 +98,8 @@ def __init__(self, path: str): self.path = path folder, file = os.path.split(self.path) - self.default_file = file.split("*")[0] + ".csv" - if self.default_file == ".csv": - self.default_file = "styles.csv" - self.default_path = os.path.join(folder, self.default_file) + filename, _, ext = file.partition('*') + self.default_path = os.path.join(folder, filename + ext) self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] @@ -155,10 +153,8 @@ def load_from_csv(self, path: str): row["name"], prompt, negative_prompt, path ) - def get_style_paths(self) -> list(): - """ - Returns a list of all distinct paths, including the default path, of - files that styles are loaded from.""" + def get_style_paths(self) -> set: + """Returns a set of all distinct paths of files that styles are loaded from.""" # Update any styles without a path to the default path for style in list(self.styles.values()): if not style.path: @@ -172,9 +168,9 @@ def get_style_paths(self) -> list(): style_paths.add(style.path) # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths.discard("do_not_save") - return list(style_paths) + return style_paths def get_style_prompts(self, styles): return [self.styles.get(x, self.no_style).prompt for x in styles] @@ -196,20 +192,7 @@ def save_styles(self, path: str = None) -> None: # The path argument is deprecated, but kept for backwards compatibility _ = path - # Update any styles without a path to the default path - for style in list(self.styles.values()): - if not style.path: - self.styles[style.name] = style._replace(path=self.default_path) - - # Create a list of all distinct paths, including the default path - style_paths = set() - style_paths.add(self.default_path) - for _, style in self.styles.items(): - if style.path: - style_paths.add(style.path) - - # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths = self.get_style_paths() csv_names = [os.path.split(path)[1].lower() for path in style_paths] From 5cb1ce470df8332872af3dfa1067b761062d4608 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 14 Dec 2023 09:48:36 +0300 Subject: [PATCH 0434/1174] Merge pull request #14266 from kaalibro/dev Re-add setting lost as part of e294e46 --- modules/shared_options.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018fb..acb6e2d4819 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -256,6 +256,7 @@ "keyedit_precision_extra": OptionInfo(0.05, "Precision for when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), + "keyedit_move": OptionInfo(True, 
"Alt+left/right moves prompt elements"), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), })) From b7e0d4a7e171ee1cef73684b8423fe4a20ca7e34 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 14 Dec 2023 09:52:23 +0300 Subject: [PATCH 0435/1174] Merge pull request #14229 from Nuullll/ipex-embedding [IPEX] Fix embedding and ControlNet --- modules/xpu_specific.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d933c790378..d8da94a0efd 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -48,3 +48,12 @@ def torch_xpu_gc(): CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) From f8871dedcfe3a67689ef333aea2fdf05a9aaffa2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 14 Dec 2023 09:59:48 +0300 Subject: [PATCH 0436/1174] Merge pull request #14230 from AUTOMATIC1111/add-option-Live-preview-in-full-page-image-viewer add option: Live preview in full page image viewer --- javascript/imageviewer.js | 2 +- modules/shared_options.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index e4dae91bc9d..625c5d148df 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -34,7 +34,7 @@ function updateOnBackgroundChange() { if (modalImage && modalImage.offsetParent) { let currentButton = selected_gallery_button(); let preview = gradioApp().querySelectorAll('.livePreview > img'); - if (preview.length > 0) { + if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) { // show preview image if available modalImage.src = preview[preview.length - 1].src; } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { diff --git a/modules/shared_options.py b/modules/shared_options.py index acb6e2d4819..41097d8e6c9 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -331,6 +331,7 @@ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), + "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"), })) 
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { From eb52c803b849cdd1fc137db4568eca5bb8373f58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 14 Dec 2023 10:03:14 +0300 Subject: [PATCH 0437/1174] Merge pull request #14216 from wfjsw/state-dict-ref-comparison change state dict comparison to ref compare --- modules/sd_disable_initialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 8863107ae6f..273a7edd8b4 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -215,7 +215,7 @@ def load_state_dict(original, module, state_dict, strict=True): would be on the meta device. """ - if state_dict == sd: + if state_dict is sd: state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} original(module, state_dict, strict=strict) From 2be85f8fe03533bf3b1ad562ea58ee8227ba3b99 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 14 Dec 2023 10:08:03 +0300 Subject: [PATCH 0438/1174] Merge pull request #14237 from ReneKroon/dev #13354 : solve lora loading issue --- extensions-builtin/Lora/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 7f814706adb..629bf85376d 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -159,7 +159,8 @@ def load_network(name, network_on_disk): bundle_embeddings = {} for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + key_network_without_network_parts, _, network_part = key_network.partition(".") + if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) From 3c0c27757944ae17a7fa4c2323ee9ae2d434dbce Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 14 Dec 2023 19:36:17 +0900 Subject: [PATCH 0439/1174] default False js_live_preview_in_modal_lightbox --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 41097d8e6c9..d2e86ff10b3 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -331,7 +331,7 @@ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), - "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"), + "js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { From 0c5427960b3a4ffe6d673c28e8e135b26f015717 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:11:59 +0900 Subject: [PATCH 0440/1174] make modal toolbar and icon opacity adjustable --- modules/shared_gradio_themes.py | 4 ++++ modules/shared_options.py | 2 ++ style.css | 4 ++-- 3 files changed, 8 insertions(+), 2 
deletions(-) diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py index 822db0a951d..b6dc31450bc 100644 --- a/modules/shared_gradio_themes.py +++ b/modules/shared_gradio_themes.py @@ -65,3 +65,7 @@ def reload_gradio_theme(theme_name=None): except Exception as e: errors.display(e, "changing gradio theme") shared.gradio_theme = gr.themes.Default(**default_theme_args) + + # append additional values gradio_theme + shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity + shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018fb..86e7636cc58 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -266,6 +266,8 @@ "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"), "js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"), "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"), + "sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), + "sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(), })) diff --git a/style.css b/style.css index ec449bdea44..6d4c8a0d505 100644 --- a/style.css +++ b/style.css @@ -679,7 +679,7 @@ table.popup-table .link{ transition: 0.2s ease background-color; } .modalControls:hover { - background-color:rgba(0,0,0,0.9); + background-color:rgba(0,0,0, var(--sd-webui-modal-lightbox-toolbar-opacity)); } .modalClose { margin-left: auto; @@ -761,7 +761,7 @@ table.popup-table .link{ .modalPrev, .modalNext, .modalControls .cursor { - opacity: 0; + opacity: var(--sd-webui-modal-lightbox-icon-opacity); } } From 1242ba08e19f3d317bdc5924db2b73d0c9569a7f Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 16:57:17 +0800 Subject: [PATCH 0441/1174] add allow specify the task id and get the location of task in the queue of pending task --- modules/api/api.py | 20 ++++++++++++++++++-- modules/api/models.py | 2 ++ modules/processing.py | 2 ++ modules/progress.py | 21 +++++++++++++++++++-- 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index e6edffe7144..5d000ae8bf6 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -33,7 +33,7 @@ import piexif import piexif.helper from contextlib import closing - +from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task def script_name_to_index(name, scripts): try: @@ -337,6 +337,10 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel return script_args def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): + task_id = create_task_id("text2img") + if txt2imgreq.force_task_id != None: + task_id = txt2imgreq.force_task_id + script_runner = scripts.scripts_txt2img if not 
script_runner.scripts: script_runner.initialize_scripts(False) @@ -363,6 +367,8 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): send_images = args.pop('send_images', True) args.pop('save_images', None) + add_task_to_queue(task_id) + with self.queue_lock: with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: p.is_api = True @@ -372,12 +378,14 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): try: shared.state.begin(job="scripts_txt2img") + start_task(task_id) if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) + finish_task(task_id) finally: shared.state.end() shared.total_tqdm.clear() @@ -387,6 +395,10 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): + task_id = create_task_id("img2img") + if img2imgreq.force_task_id != None: + task_id = img2imgreq.force_task_id + init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -423,6 +435,8 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): send_images = args.pop('send_images', True) args.pop('save_images', None) + add_task_to_queue(task_id) + with self.queue_lock: with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] @@ -433,12 +447,14 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): try: shared.state.begin(job="scripts_img2img") + start_task(task_id) if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) + finish_task(task_id) finally: shared.state.end() shared.total_tqdm.clear() @@ -514,7 +530,7 @@ def progressapi(self, req: models.ProgressRequest = Depends()): if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) - return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) + return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task) def interrogateapi(self, interrogatereq: models.InterrogateRequest): image_b64 = interrogatereq.image diff --git a/modules/api/models.py b/modules/api/models.py index 6a574771c33..7b7f1773866 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -109,6 +109,7 @@ def generate_model(self): {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, + {"key": "force_task_id", "type": str, "default": None}, ] ).generate_model() @@ -126,6 +127,7 @@ def generate_model(self): {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", 
"type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, + {"key": "force_task_id", "type": str, "default": None}, ] ).generate_model() diff --git a/modules/processing.py b/modules/processing.py index e124e7f0dd2..657cacfcace 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1023,6 +1023,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_sampler_name: str = None hr_prompt: str = '' hr_negative_prompt: str = '' + force_task_id: str = None cached_hr_uc = [None, None] cached_hr_c = [None, None] @@ -1358,6 +1359,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): inpainting_mask_invert: int = 0 initial_noise_multiplier: float = None latent_mask: Image = None + force_task_id: string = None image_mask: Any = field(default=None, init=False) diff --git a/modules/progress.py b/modules/progress.py index 69921de7281..553866db113 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -8,10 +8,13 @@ from modules.shared import opts import modules.shared as shared - +from collections import OrderedDict +import string +import random +from typing import List current_task = None -pending_tasks = {} +pending_tasks = OrderedDict() finished_tasks = [] recorded_results = [] recorded_results_limit = 2 @@ -34,6 +37,11 @@ def finish_task(id_task): if len(finished_tasks) > 16: finished_tasks.pop(0) +def create_task_id(task_type): + N = 7 + res = ''.join(random.choices(string.ascii_uppercase + + string.digits, k=N)) + return f"task({task_type}-{res})" def record_results(id_task, res): recorded_results.append((id_task, res)) @@ -44,6 +52,9 @@ def record_results(id_task, res): def add_task_to_queue(id_job): pending_tasks[id_job] = time.time() +class PendingTasksResponse(BaseModel): + size: int = Field(title="Pending task size") + tasks: List[str] = Field(title="Pending task ids") class ProgressRequest(BaseModel): id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") @@ -63,8 +74,14 @@ class ProgressResponse(BaseModel): def setup_progress_api(app): + app.add_api_route("/internal/pendingTasks", get_pending_tasks, methods=["GET"]) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) +def get_pending_tasks(): + pending_tasks_ids = [x for x in pending_tasks] + pending_len = len(pending_tasks_ids) + return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids) + def progressapi(req: ProgressRequest): active = req.id_task == current_task From d859de37d9ec10cb6c804226328a11c87c444852 Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 17:48:20 +0800 Subject: [PATCH 0442/1174] fix the problem of ruff of github --- modules/api/api.py | 4 ++-- modules/processing.py | 2 +- modules/progress.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 5d000ae8bf6..1f464806727 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -338,7 +338,7 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = create_task_id("text2img") - if txt2imgreq.force_task_id != None: + if txt2imgreq.force_task_id is None: task_id = txt2imgreq.force_task_id script_runner = scripts.scripts_txt2img @@ -396,7 +396,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): def img2imgapi(self, img2imgreq: 
models.StableDiffusionImg2ImgProcessingAPI): task_id = create_task_id("img2img") - if img2imgreq.force_task_id != None: + if img2imgreq.force_task_id is None: task_id = img2imgreq.force_task_id init_images = img2imgreq.init_images diff --git a/modules/processing.py b/modules/processing.py index 657cacfcace..5added65ed0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1359,7 +1359,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): inpainting_mask_invert: int = 0 initial_noise_multiplier: float = None latent_mask: Image = None - force_task_id: string = None + force_task_id: str = None image_mask: Any = field(default=None, init=False) diff --git a/modules/progress.py b/modules/progress.py index 553866db113..6946fb1bc2a 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -78,7 +78,7 @@ def setup_progress_api(app): return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) def get_pending_tasks(): - pending_tasks_ids = [x for x in pending_tasks] + pending_tasks_ids = list(pending_tasks) pending_len = len(pending_tasks_ids) return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids) From da45e73b4ffde2e2a85b64a3e3258a0625bd307e Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 17:57:58 +0800 Subject: [PATCH 0443/1174] fix the problem of ruff of github --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 1f464806727..9fac7e604f3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -340,7 +340,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = create_task_id("text2img") if txt2imgreq.force_task_id is None: task_id = txt2imgreq.force_task_id - + script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) From 6d7e57ba6a4d686d515518b5f90e91b32fa01caf Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 18:03:14 +0800 Subject: [PATCH 0444/1174] fix the problem of ruff of github --- modules/api/api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9fac7e604f3..8d8e70a46a2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -340,7 +340,6 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = create_task_id("text2img") if txt2imgreq.force_task_id is None: task_id = txt2imgreq.force_task_id - script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) From 0dfffe53ec11b2ee097d55efc479f8e707015db9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 09:25:08 +0300 Subject: [PATCH 0445/1174] Merge pull request #14307 from AUTOMATIC1111/default-Falst-js_live_preview_in_modal_lightbox default False js_live_preview_in_modal_lightbox --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 41097d8e6c9..d2e86ff10b3 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -331,7 +331,7 @@ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live 
preview method on interrupt").info("makes interrupts faster"), - "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"), + "js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { From ea272152e0b50dbb2bd675ec020607f3d50c37d0 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 16 Dec 2023 15:08:08 +0800 Subject: [PATCH 0446/1174] Add FP8 settings into PNG info --- modules/generation_parameters_copypaste.py | 6 ++++++ modules/processing.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4efe53e0c48..dbffe494998 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -314,6 +314,12 @@ def parse_generation_parameters(x: str): if "VAE Decoder" not in res: res["VAE Decoder"] = "Full" + if "FP8 weight" not in res: + res["FP8 weight"] = "Disable" + + if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": + res["Cache FP16 weight for LoRA"] = False + skip = set(shared.opts.infotext_skip_pasting) res = {k: v for k, v in res.items() if k not in skip} diff --git a/modules/processing.py b/modules/processing.py index bea01ec6884..179f2c0fccf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -688,6 +688,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None, + "FP8 weight": opts.fp8_storage if devices.fp8 else None, + "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None, "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None, "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), From 7745db6fc02faf19117838c1e7bcc8a60b5f5e90 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 10:15:08 +0300 Subject: [PATCH 0447/1174] torch 2.1.2 --- modules/errors.py | 4 ++-- modules/launch_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index c534a5d66fe..48aa13a1728 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -107,8 +107,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.1.0" - expected_xformers_version = "0.0.22.post7" + expected_torch_version = "2.1.2" + expected_xformers_version = "0.0.23.post1" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 2c54e2a02f4..dabef0f5335 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -315,7 +315,7 @@ def requirements_met(requirements_file): def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 
--extra-index-url {torch_index_url}") if args.use_ipex: if platform.system() == "Windows": # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main @@ -338,7 +338,7 @@ def prepare_environment(): torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") From cd9ce2e31c4a264d7cde17c54d24f8ad94c9cf2c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 10:40:20 +0300 Subject: [PATCH 0448/1174] Use radio for FP8 mode selection --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index d470eb8ffd7..fa542ba8cd6 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -206,7 +206,7 @@ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), - "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), + "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. 
Use more system ram."), })) From 93eae69895c34361a71dbed17348bcfd132fbc6a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 11:00:42 +0300 Subject: [PATCH 0449/1174] move soft inpainting to a built-in extension --- .../soft-inpainting/scripts}/soft_inpainting.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {scripts => extensions-builtin/soft-inpainting/scripts}/soft_inpainting.py (100%) diff --git a/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py similarity index 100% rename from scripts/soft_inpainting.py rename to extensions-builtin/soft-inpainting/scripts/soft_inpainting.py From 86b3aa94e2d36a4f9d5ef1bb7c6ec995ff8eb517 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 11:04:59 +0300 Subject: [PATCH 0450/1174] rename pending tasks api endpoint to be more in line with others --- modules/progress.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/progress.py b/modules/progress.py index 6946fb1bc2a..85255e821f5 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -74,9 +74,10 @@ class ProgressResponse(BaseModel): def setup_progress_api(app): - app.add_api_route("/internal/pendingTasks", get_pending_tasks, methods=["GET"]) + app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"]) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) + def get_pending_tasks(): pending_tasks_ids = list(pending_tasks) pending_len = len(pending_tasks_ids) From a97832033427096072d5ea914adac3662cda4fd1 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 16 Dec 2023 19:39:43 +0800 Subject: [PATCH 0451/1174] Let fp8-related settings to invalidate cond_cache --- modules/processing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index dd97b4eee80..9351e3fb019 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -431,6 +431,8 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps opts.sdxl_crop_top, self.width, self.height, + opts.fp8_storage, + opts.cache_fp16_weight, ) def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None): From 98c5fa92015837706adfd9975d5f345ab74f1c99 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 16 Dec 2023 22:14:39 +0900 Subject: [PATCH 0452/1174] fix extras caption BLIP #14328 --- scripts/postprocessing_caption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py index 243e3ad9c62..9482a03ca8b 100644 --- a/scripts/postprocessing_caption.py +++ b/scripts/postprocessing_caption.py @@ -25,6 +25,6 @@ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option) captions.append(deepbooru.model.tag(pp.image)) if "BLIP" in option: - captions.append(shared.interrogator.generate_caption(pp.image)) + captions.append(shared.interrogator.interrogate(pp.image.convert("RGB"))) pp.caption = ", ".join([x for x in captions if x]) From de03882d6ca56bc81058f5120f028678a6a54aaa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 17 Dec 2023 08:55:35 +0300 Subject: [PATCH 0453/1174] make task ids for API work without force_task_id --- modules/api/api.py | 9 +++------ 1 file changed, 3 
insertions(+), 6 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9637cb81b77..7154c9d5b7b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -336,9 +336,8 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel return script_args def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): - task_id = create_task_id("text2img") - if txt2imgreq.force_task_id is None: - task_id = txt2imgreq.force_task_id + task_id = txt2imgreq.force_task_id or create_task_id("txt2img") + script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) @@ -393,9 +392,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): - task_id = create_task_id("img2img") - if img2imgreq.force_task_id is None: - task_id = img2imgreq.force_task_id + task_id = img2imgreq.force_task_id or create_task_id("img2img") init_images = img2imgreq.init_images if init_images is None: From 10945aa41a158ee03727c5ea77d4ffff6b5370f0 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 18 Dec 2023 15:27:41 +0900 Subject: [PATCH 0454/1174] only rewrite ui-config when there is change and a typo --- modules/ui.py | 4 +++- modules/ui_loadsave.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index d80486dd4ac..f02b55116f5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1086,6 +1086,7 @@ def get_textual_inversion_template_names(): ) loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) + ui_settings_from_file = loadsave.ui_settings.copy() settings = ui_settings.UiSettings() settings.create_ui(loadsave, dummy_component) @@ -1146,7 +1147,8 @@ def get_textual_inversion_template_names(): modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint']) - loadsave.dump_defaults() + if ui_settings_from_file != loadsave.ui_settings: + loadsave.dump_defaults() demo.ui_loadsave = loadsave return demo diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 7826786ccde..693ff75c5d8 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -144,7 +144,7 @@ def write_to_file(self, current_ui_settings): json.dump(current_ui_settings, file, indent=4, ensure_ascii=False) def dump_defaults(self): - """saves default values to a file unless tjhe file is present and there was an error loading default values at start""" + """saves default values to a file unless the file is present and there was an error loading default values at start""" if self.error_loading and os.path.exists(self.filename): return From e4b4a9c4acf0ca375a8603f7f52fde8467b2d266 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 18:00:01 +0800 Subject: [PATCH 0455/1174] [IPEX] Slice SDPA into smaller chunks --- modules/xpu_specific.py | 66 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d8da94a0efd..0ebdd5964f1 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -27,6 +27,68 @@ def torch_xpu_gc(): has_xpu = check_for_xpu() + +# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627 +# Here we 
implement a slicing algorithm to split large batch size into smaller chunks, +# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. +# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, +# which is the best trade-off between VRAM usage and performance. +ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention +def torch_xpu_scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs +): + # cast to same dtype first + key = key.to(query.dtype) + value = value.to(query.dtype) + + N = query.shape[:-2] # Batch size + L = query.size(-2) # Target sequence length + E = query.size(-1) # Embedding dimension of the query and key + S = key.size(-2) # Source sequence length + Ev = value.size(-1) # Embedding dimension of the value + + total_batch_size = torch.numel(torch.empty(N)) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + + if total_batch_size <= batch_size_limit: + return orig_sdp_attn_func( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + *args, **kwargs + ) + + query = torch.reshape(query, (-1, L, E)) + key = torch.reshape(key, (-1, S, E)) + value = torch.reshape(value, (-1, S, Ev)) + if attn_mask is not None: + attn_mask = attn_mask.view(-1, L, S) + chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit + outputs = [] + for i in range(chunk_count): + attn_mask_chunk = ( + None + if attn_mask is None + else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :] + ) + chunk_output = orig_sdp_attn_func( + query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + attn_mask_chunk, + dropout_p, + is_causal, + *args, **kwargs + ) + outputs.append(chunk_output) + result = torch.cat(outputs, dim=0) + return torch.reshape(result, (*N, L, Ev)) + + if has_xpu: # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device CondFunc('torch.Generator', @@ -55,5 +117,5 @@ def torch_xpu_gc(): lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) + lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs), + lambda orig_func, query, *args, **kwargs: query.is_xpu) From f586f4973a0f715e30b42242bb0e6b3f88c37d90 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 19:44:52 +0800 Subject: [PATCH 0456/1174] Fix device id --- modules/xpu_specific.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 0ebdd5964f1..f7687a66cc2 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ 
-33,7 +33,7 @@ def torch_xpu_gc(): # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, # which is the best trade-off between VRAM usage and performance. -ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +ARC_SINGLE_ALLOCATION_LIMIT = {} orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention def torch_xpu_scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs @@ -49,7 +49,10 @@ def torch_xpu_scaled_dot_product_attention( Ev = value.size(-1) # Embedding dimension of the value total_batch_size = torch.numel(torch.empty(N)) - batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + device_id = query.device.index + if device_id not in ARC_SINGLE_ALLOCATION_LIMIT: + ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size())) if total_batch_size <= batch_size_limit: return orig_sdp_attn_func( From fe4d084390d35de49baf83c3319a72d71f540aee Mon Sep 17 00:00:00 2001 From: Muhammad Rehan Aslam <19831661+ranareehanaslam@users.noreply.github.com> Date: Mon, 18 Dec 2023 17:50:00 +0500 Subject: [PATCH 0457/1174] Update webui.py Added (Fixed) IPV6 Functionality When there is No Webui Argument Passed --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 9ed20b30672..3b587dc4150 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch( - server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", + server_name = cmd_opts.server_name if cmd_opts.server_name else ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), port=cmd_opts.port if cmd_opts.port else 7861, root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) From 0d5941edbc4c602b760d4200bd76e044c65a0e40 Mon Sep 17 00:00:00 2001 From: Muhammad Rehan Aslam <19831661+ranareehanaslam@users.noreply.github.com> Date: Tue, 19 Dec 2023 09:50:38 +0500 Subject: [PATCH 0458/1174] Update webui.py Co-authored-by: Aarni Koskela --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 3b587dc4150..3f5b67a3060 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch( - server_name = cmd_opts.server_name if cmd_opts.server_name else ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), + server_name=cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), port=cmd_opts.port if cmd_opts.port else 7861, root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) From 3e068de0dcec811d515402fc184f70709a785e4f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 19 Dec 2023 18:48:49 +0900 Subject: [PATCH 0459/1174] reorder training preprocessing modules in extras tab using the order from before the rework 11d23e8ca55c097ecfa255a05b63f194e25f08be --- scripts/postprocessing_caption.py | 2 +- scripts/postprocessing_create_flipped_copies.py | 2 +- scripts/postprocessing_focal_crop.py | 2 +- scripts/processing_autosized_crop.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) 
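The chunked attention from patches 0455/0456 above is easier to follow with the CondFunc plumbing stripped away. A minimal sketch of the same slicing arithmetic (assumptions: plain PyTorch on any device, the allocation limit passed in explicitly rather than derived per device from torch.xpu.get_device_properties, and query/key/value already sharing one dtype):

import math
import torch
import torch.nn.functional as F

def sliced_sdpa(query, key, value, allocation_limit_bytes):
    # the attention-weight tensor for one chunk takes roughly
    # B * L * S * element_size bytes, so cap the flattened batch B to fit
    L, S, Ev = query.size(-2), key.size(-2), value.size(-1)
    batch = math.prod(query.shape[:-2])
    limit = max(1, allocation_limit_bytes // (L * S * query.element_size()))
    q = query.reshape(-1, L, query.size(-1))
    k = key.reshape(-1, S, key.size(-1))
    v = value.reshape(-1, S, Ev)
    chunks = [F.scaled_dot_product_attention(q[i:i + limit], k[i:i + limit], v[i:i + limit])
              for i in range(0, batch, limit)]
    return torch.cat(chunks, dim=0).reshape(*query.shape[:-2], L, Ev)

out = sliced_sdpa(torch.randn(8, 4, 77, 64), torch.randn(8, 4, 77, 64),
                  torch.randn(8, 4, 77, 64), allocation_limit_bytes=2 ** 18)
print(out.shape)  # torch.Size([8, 4, 77, 64])

With a 256 KiB limit the 32 flattened batch rows run in three chunks of at most 11; since each batch row's attention is independent, the concatenated result matches a single full-sized call, just with a bounded peak allocation. The real hook additionally casts key and value to the query's dtype and caches a per-device limit, as the diffs above show.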
diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py index 9482a03ca8b..5592a89870e 100644 --- a/scripts/postprocessing_caption.py +++ b/scripts/postprocessing_caption.py @@ -4,7 +4,7 @@ class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing): name = "Caption" - order = 4000 + order = 4040 def ui(self): with ui_components.InputAccordion(False, label="Caption") as enable: diff --git a/scripts/postprocessing_create_flipped_copies.py b/scripts/postprocessing_create_flipped_copies.py index 3425571dc3b..b673003b6ea 100644 --- a/scripts/postprocessing_create_flipped_copies.py +++ b/scripts/postprocessing_create_flipped_copies.py @@ -6,7 +6,7 @@ class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing): name = "Create flipped copies" - order = 4000 + order = 4030 def ui(self): with ui_components.InputAccordion(False, label="Create flipped copies") as enable: diff --git a/scripts/postprocessing_focal_crop.py b/scripts/postprocessing_focal_crop.py index d3baf29878a..cff1dbc5470 100644 --- a/scripts/postprocessing_focal_crop.py +++ b/scripts/postprocessing_focal_crop.py @@ -7,7 +7,7 @@ class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): name = "Auto focal point crop" - order = 4000 + order = 4010 def ui(self): with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: diff --git a/scripts/processing_autosized_crop.py b/scripts/processing_autosized_crop.py index c098022645d..7e674989814 100644 --- a/scripts/processing_autosized_crop.py +++ b/scripts/processing_autosized_crop.py @@ -28,7 +28,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing): name = "Auto-sized crop" - order = 4000 + order = 4020 def ui(self): with ui_components.InputAccordion(False, label="Auto-sized crop") as enable: From 9feb034e343d6d7ef63395821658fb3774b30a24 Mon Sep 17 00:00:00 2001 From: wangqyqq Date: Thu, 21 Dec 2023 20:15:51 +0800 Subject: [PATCH 0460/1174] support for sdxl-inpaint model --- configs/sd_xl_inpaint.yaml | 98 +++++++++++++++++++++++++++++++++++++ modules/processing.py | 19 +++++++ modules/sd_models_config.py | 6 ++- modules/sd_models_xl.py | 5 ++ 4 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 configs/sd_xl_inpaint.yaml diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml new file mode 100644 index 00000000000..3bad372186f --- /dev/null +++ b/configs/sd_xl_inpaint.yaml @@ -0,0 +1,98 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the 
first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/modules/processing.py b/modules/processing.py index 6f01c95f5b6..159548dba7d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -106,6 +106,20 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: + sd = sd_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. @@ -362,6 +376,11 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) + sd = self.sampler.model_wrap.inner_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + # Dummy zero conditioning if we're not using inpainting or depth model. 
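# (Aside on the 9-channel branch above, sketching the assumed SDXL inpainting
# layout rather than anything spelled out in this diff: c_concat packs a
# 1-channel mask in front of the 4-channel VAE latent of the masked image
# (the pad(..., (0, 0, 0, 0, 1, 0), value=1.0) above is what prepends that
# mask channel), so the concat added to sd_models_xl.apply_model later in
# this patch, torch.cat([x] + cond['c_concat'], dim=1), feeds 4 + (1 + 4) = 9
# channels into a UNet declared with in_channels: 9. The weight
# 'diffusion_model.input_blocks.0.0.weight' therefore has shape
# (320, 9, 3, 3), and its dim 1 is the in_channels that each
# shape[1] == 9 test inspects.)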
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index deab2f6e237..b38137eb5a9 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -15,6 +15,7 @@ config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") +config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename): sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: - return config_sdxl + if diffusion_model_input.shape[1] == 9: + return config_sdxl_inpainting + else: + return config_sdxl if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 0112332161f..d8a9a73bcb3 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -34,6 +34,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + sd = self.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) + return self.model(x, t, cond) From de1809bd14450cfc41623b6021c7087fb385ab6f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 22 Dec 2023 00:08:35 +0900 Subject: [PATCH 0461/1174] handle axis_type is None --- scripts/xyz_grid.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index b2250c04dc0..e5083874af9 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -476,6 +476,8 @@ def fill(axis_type, csv_mode): fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown]) def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode): + axis_type = axis_type or 0 # if axle type is None set to 0 + choices = self.current_axis_options[axis_type].choices has_choices = choices is not None @@ -526,6 +528,8 @@ def get_dropdown_update_from_params(axis, params): return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode] def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode): + x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0 + if not no_fixed_seeds: modules.processing.fix_seed(p) From edfae95d90a49ea95394b772817a59dde4175222 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 23 
Dec 2023 01:21:00 +0900 Subject: [PATCH 0462/1174] prevent crash due to Script __init__ exception --- modules/scripts.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/scripts.py b/modules/scripts.py index b6fcf96e047..3a76691186b 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -566,7 +566,12 @@ def initialize_scripts(self, is_img2img): auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() for script_data in auto_processing_scripts + scripts_data: - script = script_data.script_class() + try: + script = script_data.script_class() + except Exception: + errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True) + continue + script.filename = script_data.path script.is_txt2img = not is_img2img script.is_img2img = is_img2img From 00d4a4d4ac75903d8224e9beb1136584dd66fcd8 Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Tue, 26 Dec 2023 14:46:29 +0800 Subject: [PATCH 0463/1174] move thread-unsafe code to __init__ --- modules/api/api.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7154c9d5b7b..f0a68c672d6 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -251,6 +251,15 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] + script_runner = scripts.scripts_img2img + if not script_runner.scripts: + script_runner.initialize_scripts(True) + ui.create_ui() + if not self.default_script_arg_txt2img: + self.default_script_arg_txt2img = self.init_default_script_args(script_runner) + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(script_runner) + def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) @@ -339,11 +348,6 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - if not self.default_script_arg_txt2img: - self.default_script_arg_txt2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params @@ -403,11 +407,6 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img - if not script_runner.scripts: - script_runner.initialize_scripts(True) - ui.create_ui() - if not self.default_script_arg_img2img: - self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params From bfe418a58d39c69ca2672e7d8a1fd7ad2b34869b Mon Sep 17 00:00:00 2001 From: wangqyqq Date: Wed, 27 Dec 2023 10:20:56 +0800 Subject: [PATCH 0464/1174] add some codes for robust --- modules/processing.py | 24 +++++++++++++----------- modules/sd_models_xl.py | 5 +++-- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 
159548dba7d..c05e608ab99 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -108,17 +108,18 @@ def txt2img_image_conditioning(sd_model, x, width, height): else: sd = sd_model.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - # The "masked-image" in this case will just be all 0.5 since the entire image is masked. - image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 - image_conditioning = images_tensor_to_samples(image_conditioning, - approximation_indexes.get(opts.sd_vae_encode_method)) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. @@ -378,8 +379,9 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None sd = self.sampler.model_wrap.inner_model.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. 
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index d8a9a73bcb3..162d0fee892 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -36,8 +36,9 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): sd = self.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - x = torch.cat([x] + cond['c_concat'], dim=1) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond) From de04573438bc111f137359b8f4998780bf315275 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 28 Dec 2023 06:22:51 +0900 Subject: [PATCH 0465/1174] create utility truncate_path utli.truncate_path(target_path, base_path) return the target_path relative to base_path if target_path is a sub path of base_path else return the absolute path --- modules/util.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/modules/util.py b/modules/util.py index 60afc0670c7..4861bcb0861 100644 --- a/modules/util.py +++ b/modules/util.py @@ -2,7 +2,7 @@ import re from modules import shared -from modules.paths_internal import script_path +from modules.paths_internal import script_path, cwd def natural_sort_key(s, regex=re.compile('([0-9]+)')): @@ -56,3 +56,13 @@ def ldm_print(*args, **kwargs): return print(*args, **kwargs) + + +def truncate_path(target_path, base_path=cwd): + abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path) + try: + if os.path.commonpath([abs_target, abs_base]) == abs_base: + return os.path.relpath(abs_target, abs_base) + except ValueError: + pass + return abs_target From af2951ed53da6d357aea9232538f9ea7e1cdc648 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 28 Dec 2023 06:52:33 +0900 Subject: [PATCH 0466/1174] base default image output on data_path Co-Authored-By: Alberto Cano <34340962+canoalberto@users.noreply.github.com> --- modules/paths_internal.py | 1 + modules/shared_options.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/modules/paths_internal.py b/modules/paths_internal.py index 89131a54fa1..b86ecd7f192 100644 --- a/modules/paths_internal.py +++ b/modules/paths_internal.py @@ -28,5 +28,6 @@ extensions_dir = os.path.join(data_path, "extensions") extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") config_states_dir = os.path.join(script_path, "config_states") +default_output_dir = os.path.join(data_path, "output") roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf') diff --git a/modules/shared_options.py b/modules/shared_options.py index fa542ba8cd6..752a4f1259a 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -1,7 +1,8 @@ +import os import gradio as gr -from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes -from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 +from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util +from modules.paths_internal import 
models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401 from modules.shared_cmd_options import cmd_opts from modules.options import options_section, OptionInfo, OptionHTML, categories @@ -74,14 +75,14 @@ options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), { "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), - "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), - "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), - "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs), + "outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs), + "outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs), + "outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs), "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs), - "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs), - "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), - "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), - "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs), + "outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs), + "outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs), + "outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs), + "outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs), })) options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), { From 892e703b59b2f867d8a202a52fab1db89882ef86 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 28 Dec 2023 06:52:41 +0900 Subject: [PATCH 0467/1174] webpath use truncate_path --- modules/ui_gradio_extensions.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index 0d368f8b2c4..a86c368ef70 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -1,17 +1,12 @@ import os import gradio as gr -from modules import localization, shared, 
scripts -from modules.paths import script_path, data_path, cwd +from modules import localization, shared, scripts, util +from modules.paths import script_path, data_path def webpath(fn): - if fn.startswith(cwd): - web_path = os.path.relpath(fn, cwd) - else: - web_path = os.path.abspath(fn) - - return f'file={web_path}?{os.path.getmtime(fn)}' + return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}' def javascript_html(): From dc57ec0296e768ee91290e16ab262404837c566d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 29 Dec 2023 01:56:48 +0900 Subject: [PATCH 0468/1174] save info of init image --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 9351e3fb019..141f2f11167 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1482,7 +1482,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): # Save init image if opts.save_init_img: self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() - images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False) + images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info) image = images.flatten(img, opts.img2img_background_color) From bb07cb6a0df60a96827125ffc09ea182a1ed272c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 17 Dec 2023 10:22:03 +0300 Subject: [PATCH 0469/1174] a --- modules/api/api.py | 27 +++++++++++++ modules/api/models.py | 2 + modules/generation_parameters_copypaste.py | 19 +++++++++ modules/processing.py | 2 +- modules/processing_scripts/refiner.py | 7 ++-- modules/processing_scripts/seed.py | 13 +++--- modules/ui.py | 46 +++++++++++----------- 7 files changed, 83 insertions(+), 33 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7154c9d5b7b..b3d70940401 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -335,6 +335,29 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] return script_args + def apply_infotext(self, request, tabname): + if not request.infotext: + return {} + + params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + + for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]: + if not field.api: + continue + + value = field.function(params) if field.function else params.get(field.label) + target_type = request.__fields__[field.api].type_ + + if value is None: + continue + + if not isinstance(value, target_type): + value = target_type(value) + + setattr(request, field.api, value) + + return params + def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = txt2imgreq.force_task_id or create_task_id("txt2img") @@ -342,6 +365,9 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): if not script_runner.scripts: script_runner.initialize_scripts(False) ui.create_ui() + + infotext_params = self.apply_infotext(txt2imgreq, "txt2img") + if not self.default_script_arg_txt2img: self.default_script_arg_txt2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) @@ -358,6 +384,7 @@ def text2imgapi(self, txt2imgreq: 
models.StableDiffusionTxt2ImgProcessingAPI): args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) + args.pop('infotext', None) script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) diff --git a/modules/api/models.py b/modules/api/models.py index 58083a34fb0..16edf11cf83 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -108,6 +108,7 @@ def generate_model(self): {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, {"key": "force_task_id", "type": str, "default": None}, + {"key": "infotext", "type": str, "default": None}, ] ).generate_model() @@ -126,6 +127,7 @@ def generate_model(self): {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, {"key": "force_task_id", "type": str, "default": None}, + {"key": "infotext", "type": str, "default": None}, ] ).generate_model() diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index dbffe494998..4b4727c4458 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -28,6 +28,19 @@ def __init__(self, paste_button, tabname, source_text_component=None, source_ima self.paste_field_names = paste_field_names or [] +class PasteField(tuple): + def __new__(cls, component, target, *, api=None): + return super().__new__(cls, (component, target)) + + def __init__(self, component, target, *, api=None): + super().__init__() + + self.api = api + self.component = component + self.label = target if isinstance(target, str) else None + self.function = target if callable(target) else None + + paste_fields: dict[str, dict] = {} registered_param_bindings: list[ParamBinding] = [] @@ -84,6 +97,12 @@ def image_from_url_text(filedata): def add_paste_fields(tabname, init_img, fields, override_settings_component=None): + + if fields: + for i in range(len(fields)): + if not isinstance(fields[i], PasteField): + fields[i] = PasteField(*fields[i]) + paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component} # backwards compatibility for existing extensions diff --git a/modules/processing.py b/modules/processing.py index 9351e3fb019..ee2ccf460d2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1135,7 +1135,7 @@ def calculate_target_resolution(self): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: - if self.hr_checkpoint_name: + if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint': self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) if self.hr_checkpoint_info is None: diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index 29ccb78f903..cefad32b761 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,6 +1,7 @@ import gradio as gr from modules import scripts, sd_models +from modules.generation_parameters_copypaste import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion @@ -31,9 +32,9 @@ def lookup_checkpoint(title): return None if info is None else info.title self.infotext_fields = [ - (enable_refiner, lambda d: 'Refiner' in d), - 
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))), - (refiner_switch_at, 'Refiner switch at'), + PasteField(enable_refiner, lambda d: 'Refiner' in d), + PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"), + PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"), ] return enable_refiner, refiner_checkpoint, refiner_switch_at diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index dc9c2da5000..a3e16a12e9b 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,6 +3,7 @@ import gradio as gr from modules import scripts, ui, errors +from modules.generation_parameters_copypaste import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton @@ -51,12 +52,12 @@ def ui(self, is_img2img): seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras]) self.infotext_fields = [ - (self.seed, "Seed"), - (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), + PasteField(self.seed, "Seed", api="seed"), + PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), + PasteField(subseed, "Variation seed", api="subseed"), + PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"), + PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"), + PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"), ] self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}') diff --git a/modules/ui.py b/modules/ui.py index d80486dd4ac..9db2407ed30 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -28,7 +28,7 @@ import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.generation_parameters_copypaste import image_from_url_text +from modules.generation_parameters_copypaste import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component @@ -436,28 +436,28 @@ def create_ui(): ) txt2img_paste_fields = [ - (toprow.prompt, "Prompt"), - (toprow.negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_name, "Sampler"), - (cfg_scale, "CFG scale"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - (hr_checkpoint_name, "Hires checkpoint"), - (hr_sampler_name, "Hires sampler"), - (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), - (hr_prompt, "Hires prompt"), - (hr_negative_prompt, "Hires negative prompt"), - (hr_prompts_container, lambda d: 
gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), + PasteField(toprow.prompt, "Prompt", api="prompt"), + PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"), + PasteField(steps, "Steps", api="steps"), + PasteField(sampler_name, "Sampler", api="sampler_name"), + PasteField(cfg_scale, "CFG scale", api="cfg_scale"), + PasteField(width, "Size-1", api="width"), + PasteField(height, "Size-2", api="height"), + PasteField(batch_size, "Batch size", api="batch_size"), + PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"), + PasteField(denoising_strength, "Denoising strength", api="denoising_strength"), + PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"), + PasteField(hr_scale, "Hires upscale", api="hr_scale"), + PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"), + PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"), + PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"), + PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"), + PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"), + PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"), + PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), + PasteField(hr_prompt, "Hires prompt", api="hr_prompt"), + PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"), + PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), *scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) From 59d060fd5ea93fcc3fdbfbd13b6e20fda06ecf94 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 30 Dec 2023 17:11:03 +0900 Subject: [PATCH 0470/1174] More lora not found warning --- extensions-builtin/Lora/networks.py | 8 +++++++- extensions-builtin/Lora/scripts/lora_script.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 985b2753bce..72ebd624169 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,4 @@ +import gradio as gr import logging import os import re @@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No emb_db.skipped_embeddings[name] = embedding if failed_to_load_networks: - sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) + lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}' + sd_hijack.model_hijack.comments.append(lora_not_found_message) + if shared.opts.lora_not_found_warning_console: + print(f'\n{lora_not_found_message}\n') + if shared.opts.lora_not_found_gradio_warning: + gr.Warning(lora_not_found_message) purge_networks_from_memory() diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index ef23968c563..1518f7e5c89 100644 --- 
a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -39,6 +39,8 @@ def before_ui(): "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), + "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), + "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), })) From ba92135a2ba9e210ce5370715e2defcb43df70d1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 12:11:09 +0300 Subject: [PATCH 0471/1174] add override_settings support for infotext API --- modules/api/api.py | 10 ++++ modules/generation_parameters_copypaste.py | 66 ++++++++++++++-------- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index b3d70940401..fb108486fab 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -341,6 +341,7 @@ def apply_infotext(self, request, tabname): params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + handled_fields = {} for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]: if not field.api: continue @@ -355,6 +356,15 @@ def apply_infotext(self, request, tabname): value = target_type(value) setattr(request, field.api, value) + handled_fields[field.label] = 1 + + if request.override_settings is None: + request.override_settings = {} + + overriden_settings = generation_parameters_copypaste.get_override_settings(params, skip_fields=handled_fields) + for infotext_text, setting_name, value in overriden_settings: + if setting_name not in request.override_settings: + request.override_settings[setting_name] = value return params diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4b4727c4458..86a36c3273c 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -390,6 +390,48 @@ def create_override_settings_dict(text_pairs): return res +def get_override_settings(params, *, skip_fields=None): + """Returns a list of settings overrides from the infotext parameters dictionary. + + This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns + a list of tuples containing the parameter name, setting name, and new value cast to correct type. + + It checks for conditions before adding an override: + - ignores settings that match the current value + - ignores parameter keys present in skip_fields argument. 
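+    - ignores the sd_model_checkpoint setting when shared.opts.disable_weights_auto_swap is enabled.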
+ + Example input: + {"Clip skip": "2"} + + Example output: + [("Clip skip", "CLIP_stop_at_last_layers", 2)] + """ + + res = [] + + mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: + if param_name in (skip_fields or {}): + continue + + v = params.get(param_name, None) + if v is None: + continue + + if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: + continue + + v = shared.opts.cast_value(setting_name, v) + current_value = getattr(shared.opts, setting_name, None) + + if v == current_value: + continue + + res.append((param_name, setting_name, v)) + + return res + + def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: @@ -431,29 +473,9 @@ def paste_func(prompt): already_handled_fields = {key: 1 for _, key in paste_fields} def paste_settings(params): - vals = {} - - mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] - for param_name, setting_name in mapping + infotext_to_setting_name_mapping: - if param_name in already_handled_fields: - continue - - v = params.get(param_name, None) - if v is None: - continue - - if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: - continue - - v = shared.opts.cast_value(setting_name, v) - current_value = getattr(shared.opts, setting_name, None) - - if v == current_value: - continue - - vals[param_name] = v + vals = get_override_settings(params, skip_fields=already_handled_fields) - vals_pairs = [f"{k}: {v}" for k, v in vals.items()] + vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals] return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs)) From 8b08b78c03f09898455d54cf099225ed5f8de1ee Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 12:27:23 +0300 Subject: [PATCH 0472/1174] make it so that if an option from infotext conflicts with an argument from API, the latter overrides the former --- modules/api/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index fb108486fab..cabccb4c025 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -339,6 +339,7 @@ def apply_infotext(self, request, tabname): if not request.infotext: return {} + set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) handled_fields = {} @@ -346,6 +347,9 @@ def apply_infotext(self, request, tabname): if not field.api: continue + if field.api in set_fields: + continue + value = field.function(params) if field.function else params.get(field.label) target_type = request.__fields__[field.api].type_ @@ -376,7 +380,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): script_runner.initialize_scripts(False) ui.create_ui() - infotext_params = self.apply_infotext(txt2imgreq, "txt2img") + self.apply_infotext(txt2imgreq, "txt2img") if not self.default_script_arg_txt2img: self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
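The infotext changes in the patches above let an API caller send a complete infotext string and still override individual parameters, since a field set explicitly on the request wins over the value parsed from the infotext. A minimal sketch of such a call, assuming a local server at the default address and the third-party requests library (the /sdapi/v1/txt2img endpoint and the base64 "images" response predate these patches):

import requests

payload = {
    # a full infotext, e.g. copied from an image's PNG info
    "infotext": "a photo of a cat\nNegative prompt: blurry\nSteps: 20, Sampler: Euler a, CFG scale: 7, Seed: 1234, Size: 512x512",
    "steps": 30,  # explicitly-set fields take precedence over "Steps: 20" from the infotext
}
response = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
response.raise_for_status()
images = response.json()["images"]  # base64-encoded result images

From 0aacd4c72b4008d7153e747301fe8c5ffca57f85 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 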
<16777216c@gmail.com> Date: Sat, 30 Dec 2023 13:33:18 +0300 Subject: [PATCH 0473/1174] add support for alwayson scripts for infotext API --- modules/api/api.py | 61 +++++++++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index cabccb4c025..946cfe4a92f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -312,8 +312,13 @@ def init_default_script_args(self, script_runner): script_args[script.args_from:script.args_to] = ui_default_values return script_args - def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner): + def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None): script_args = default_script_args.copy() + + if input_script_args is not None: + for index, value in input_script_args.items(): + script_args[index] = value + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() if selectable_scripts: script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args @@ -335,41 +340,58 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] return script_args - def apply_infotext(self, request, tabname): + def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None): if not request.infotext: return {} + possible_fields = generation_parameters_copypaste.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) - handled_fields = {} - for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]: - if not field.api: - continue - - if field.api in set_fields: - continue - + def get_field_value(field, params): value = field.function(params) if field.function else params.get(field.label) - target_type = request.__fields__[field.api].type_ - if value is None: - continue + return None + + if field.api in request.__fields__: + target_type = request.__fields__[field.api].type_ + else: + target_type = type(field.component.value) + + if target_type == type(None): + return None if not isinstance(value, target_type): value = target_type(value) - setattr(request, field.api, value) - handled_fields[field.label] = 1 + return value + + for field in possible_fields: + if not field.api: + continue + + if field.api in set_fields: + continue + + value = get_field_value(field, params) + if value is not None: + setattr(request, field.api, value) if request.override_settings is None: request.override_settings = {} - overriden_settings = generation_parameters_copypaste.get_override_settings(params, skip_fields=handled_fields) - for infotext_text, setting_name, value in overriden_settings: + overriden_settings = generation_parameters_copypaste.get_override_settings(params) + for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value + if script_runner is not None and mentioned_script_args is not None: + indexes = {v: i for i, v in enumerate(script_runner.inputs)} + script_fields = ((field, 
indexes[field.component]) for field in possible_fields if field.component in indexes) + + for field, index in script_fields: + mentioned_script_args[index] = get_field_value(field, params) + return params def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): @@ -380,7 +402,8 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): script_runner.initialize_scripts(False) ui.create_ui() - self.apply_infotext(txt2imgreq, "txt2img") + infotext_script_args = {} + self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) if not self.default_script_arg_txt2img: self.default_script_arg_txt2img = self.init_default_script_args(script_runner) @@ -400,7 +423,7 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): args.pop('alwayson_scripts', None) args.pop('infotext', None) - script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) + script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) send_images = args.pop('send_images', True) args.pop('save_images', None) From 11a435b4697c2d735a117f31944c4ebe59c2504c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 13:34:46 +0300 Subject: [PATCH 0474/1174] img2img support for infotext API --- modules/api/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 946cfe4a92f..2c8dc2a0f6e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -470,6 +470,10 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): if not script_runner.scripts: script_runner.initialize_scripts(True) ui.create_ui() + + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + if not self.default_script_arg_img2img: self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) @@ -489,7 +493,7 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) - script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner) + script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) send_images = args.pop('send_images', True) args.pop('save_images', None) From 8f1826375943718463cec3af97a37886249bdb44 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 13:48:25 +0300 Subject: [PATCH 0475/1174] fix bad values read from infotext for API, add comment --- modules/api/api.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 2c8dc2a0f6e..2918f785750 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -341,6 +341,13 @@ def init_script_args(self, request, default_script_args, selectable_scripts, sel return script_args def apply_infotext(self, request, 
tabname, *, script_runner=None, mentioned_script_args=None): + """Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext. + + If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored. + + Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext. + """ + if not request.infotext: return {} @@ -361,7 +368,10 @@ def get_field_value(field, params): if target_type == type(None): return None - if not isinstance(value, target_type): + if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value + value = value.get('value') + + if value is not None and not isinstance(value, target_type): value = target_type(value) return value @@ -390,7 +400,12 @@ def get_field_value(field, params): script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes) for field, index in script_fields: - mentioned_script_args[index] = get_field_value(field, params) + value = get_field_value(field, params) + + if value is None: + continue + + mentioned_script_args[index] = value return params From f0e2e8b930115012976f7c5bae00e243a7ebbf79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 15:12:48 +0300 Subject: [PATCH 0476/1174] update #14354 --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 3f5b67a3060..2c417168aa6 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch( - server_name=cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), + server_name=initialize_util.gradio_server_name(), port=cmd_opts.port if cmd_opts.port else 7861, root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) From c069c2c5628728c9506dd034ef98e6335fd5bb34 Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Sat, 30 Dec 2023 21:32:22 +0800 Subject: [PATCH 0477/1174] add locks to ensure init args are thread-safe --- modules/api/api.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index f0a68c672d6..45c5c50773a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -251,14 +251,10 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] - script_runner = scripts.scripts_img2img - if not script_runner.scripts: - script_runner.initialize_scripts(True) - ui.create_ui() - if not self.default_script_arg_txt2img: - self.default_script_arg_txt2img = self.init_default_script_args(script_runner) - if not self.default_script_arg_img2img: - self.default_script_arg_img2img = self.init_default_script_args(script_runner) + self.txt2img_script_arg_init_lock = Lock() + self.img2img_script_arg_init_lock = Lock() + + def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -348,6 +344,12 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img + with self.txt2img_script_arg_init_lock: + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui()
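+ # build the default script args only once; the lock above keeps a concurrent first request from initializing them twice + if not self.default_script_arg_txt2img: + self.default_script_arg_txt2img = 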
self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params @@ -407,6 +409,12 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img + with self.img2img_script_arg_init_lock: + if not script_runner.scripts: + script_runner.initialize_scripts(True) + ui.create_ui() + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params From 31992eff9b9714c158b12cec16dfe66c76270dfa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 16:51:02 +0300 Subject: [PATCH 0478/1174] make it possible again to extract styles that have whitespace at the end. --- modules/styles.py | 47 +++++++++++++++++------------------------------ 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 81d9800d184..026c43001a4 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -30,38 +30,29 @@ def apply_styles_to_prompt(prompt, styles): return prompt -def unwrap_style_text_from_prompt(style_text, prompt): - """ - Checks the prompt to see if the style text is wrapped around it. If so, - returns True plus the prompt text without the style text. Otherwise, returns - False with the original prompt. +def extract_style_text_from_prompt(style_text, prompt): + """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt. - Note that the "cleaned" version of the style text is only used for matching - purposes here. It isn't returned; the original style text is not modified. + extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg") + extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg") + extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg") """ - stripped_prompt = prompt - stripped_style_text = style_text + + stripped_prompt = prompt.strip() + stripped_style_text = style_text.strip() + if "{prompt}" in stripped_style_text: - # Work out whether the prompt is wrapped in the style text. If so, we - # return True and the "inner" prompt text that isn't part of the style. - try: - left, right = stripped_style_text.split("{prompt}", 2) - except ValueError as e: - # If the style text has multple "{prompt}"s, we can't split it into - # two parts. This is an error, but we can't do anything about it. 
- print(f"Unable to compare style text to prompt:\n{style_text}") - print(f"Error: {e}") - return False, prompt + left, right = stripped_style_text.split("{prompt}", 2) if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): - prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)] + prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)] return True, prompt else: - # Work out whether the given prompt ends with the style text. If so, we - # return True and the prompt text up to where the style text starts. if stripped_prompt.endswith(stripped_style_text): - prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)] - if prompt.endswith(", "): + prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)] + + if prompt.endswith(', '): prompt = prompt[:-2] + return True, prompt return False, prompt @@ -76,15 +67,11 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt): if not style.prompt and not style.negative_prompt: return False, prompt, negative_prompt - match_positive, extracted_positive = unwrap_style_text_from_prompt( - style.prompt, prompt - ) + match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt) if not match_positive: return False, prompt, negative_prompt - match_negative, extracted_negative = unwrap_style_text_from_prompt( - style.negative_prompt, negative_prompt - ) + match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt) if not match_negative: return False, prompt, negative_prompt From 7aa27b000a3087dcb5cc7254600064bf70cacd3e Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:44:15 +0200 Subject: [PATCH 0479/1174] Add types to split_grid --- modules/images.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/images.py b/modules/images.py index 16f9ae7cce1..d30e8865d7c 100644 --- a/modules/images.py +++ b/modules/images.py @@ -64,9 +64,8 @@ def image_grid(imgs, batch_size=1, rows=None): Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"]) -def split_grid(image, tile_w=512, tile_h=512, overlap=64): - w = image.width - h = image.height +def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid: + w, h = image.size non_overlap_width = tile_w - overlap non_overlap_height = tile_h - overlap From 12c6f37f8e4b1d1d643c9d8d5dfc763c3203c728 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:01:45 +0200 Subject: [PATCH 0480/1174] Add tile_count property to Grid --- modules/images.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index d30e8865d7c..87a7bf22139 100644 --- a/modules/images.py +++ b/modules/images.py @@ -61,7 +61,13 @@ def image_grid(imgs, batch_size=1, rows=None): return grid -Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"]) +class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])): + @property + def tile_count(self) -> int: + """ + The total number of tiles in the grid. 
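+        Rows in self.tiles are stored as (y, h, tiles) triples, so this sums the length of each row's tile list.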
+ """ + return sum(len(row[2]) for row in self.tiles) def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid: From e472383acbb9e07dca311abe5fb16ee2675e410a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:04:33 +0200 Subject: [PATCH 0481/1174] Refactor esrgan_upscale to more generic upscale_with_model --- modules/esrgan_model.py | 47 +++++----------------------- modules/upscaler_utils.py | 66 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 39 deletions(-) create mode 100644 modules/upscaler_utils.py diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 02a1727d280..c0d22a99273 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,13 +1,12 @@ import sys -import numpy as np import torch -from PIL import Image import modules.esrgan_model_arch as arch -from modules import modelloader, images, devices +from modules import modelloader, devices from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model def mod2normal(state_dict): @@ -190,40 +189,10 @@ def load_model(self, path: str): return model -def upscale_without_tiling(model, img): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') - - def esrgan_upscale(model, img): - if opts.ESRGAN_tile == 0: - return upscale_without_tiling(model, img) - - grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) - newtiles = [] - scale_factor = 1 - - for y, h, row in grid.tiles: - newrow = [] - for tiledata in row: - x, w, tile = tiledata - - output = upscale_without_tiling(model, tile) - scale_factor = output.width // tile.width - - newrow.append([x * scale_factor, w * scale_factor, output]) - newtiles.append([y * scale_factor, h * scale_factor, newrow]) - - newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) - output = images.combine_grid(newgrid) - return output + return upscale_with_model( + model, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + ) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py new file mode 100644 index 00000000000..8bdda51c4c2 --- /dev/null +++ b/modules/upscaler_utils.py @@ -0,0 +1,66 @@ +import logging +from typing import Callable + +import numpy as np +import torch +import tqdm +from PIL import Image + +from modules import devices, images + +logger = logging.getLogger(__name__) + + +def upscale_without_tiling(model, img: Image.Image): + img = np.array(img) + img = img[:, :, ::-1] + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(devices.device_esrgan) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. 
* np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + return Image.fromarray(output, 'RGB') + + +def upscale_with_model( + model: Callable[[torch.Tensor], torch.Tensor], + img: Image.Image, + *, + tile_size: int, + tile_overlap: int = 0, + desc="tiled upscale", +) -> Image.Image: + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img) + output = upscale_without_tiling(model, img) + logger.debug("=> %s", output) + return output + + grid = images.split_grid(img, tile_size, tile_size, tile_overlap) + newtiles = [] + + with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + for y, h, row in grid.tiles: + newrow = [] + for x, w, tile in row: + logger.debug("Tile (%d, %d) %s...", x, y, tile) + output = upscale_without_tiling(model, tile) + scale_factor = output.width // tile.width + logger.debug("=> %s (scale factor %s)", output, scale_factor) + newrow.append([x * scale_factor, w * scale_factor, output]) + p.update(1) + newtiles.append([y * scale_factor, h * scale_factor, newrow]) + + newgrid = images.Grid( + newtiles, + tile_w=grid.tile_w * scale_factor, + tile_h=grid.tile_h * scale_factor, + image_w=grid.image_w * scale_factor, + image_h=grid.image_h * scale_factor, + overlap=grid.overlap * scale_factor, + ) + return images.combine_grid(newgrid) From b0f59342346b1c8b405f97c0e0bb01c6ae05c601 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:43:51 +0200 Subject: [PATCH 0482/1174] Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR) --- .../ScuNET/scripts/scunet_model.py | 13 +- .../ScuNET/scunet_model_arch.py | 268 ----- .../SwinIR/scripts/swinir_model.py | 126 +- .../SwinIR/swinir_model_arch.py | 867 -------------- .../SwinIR/swinir_model_arch_v2.py | 1017 ----------------- modules/codeformer/codeformer_arch.py | 276 ----- modules/codeformer/vqgan_arch.py | 435 ------- modules/codeformer_model.py | 195 ++-- modules/esrgan_model.py | 153 +-- modules/esrgan_model_arch.py | 465 -------- modules/gfpgan_model.py | 13 +- modules/launch_utils.py | 7 - modules/modelloader.py | 16 + modules/paths.py | 1 - modules/realesrgan_model.py | 153 +-- modules/sysinfo.py | 2 - modules/upscaler.py | 3 + requirements.txt | 3 +- requirements_versions.txt | 4 +- 19 files changed, 263 insertions(+), 3754 deletions(-) delete mode 100644 extensions-builtin/ScuNET/scunet_model_arch.py delete mode 100644 extensions-builtin/SwinIR/swinir_model_arch.py delete mode 100644 extensions-builtin/SwinIR/swinir_model_arch_v2.py delete mode 100644 modules/codeformer/codeformer_arch.py delete mode 100644 modules/codeformer/vqgan_arch.py delete mode 100644 modules/esrgan_model_arch.py diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 167d2f64b8e..18cf8e1a0b6 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -7,9 +7,7 @@ import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet -from modules.modelloader import load_file_from_url from modules.shared import opts @@ -120,17 +118,10 @@ def load_model(self, path: str): device = devices.get_device_for('scunet') if path.startswith("http"): # TODO: this doesn't use `path` at all? 
- filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") + filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for _, v in model.named_parameters(): - v.requires_grad = False - model = model.to(device) - - return model + return modelloader.load_spandrel_model(filename, device=device) def on_ui_settings(): diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py deleted file mode 100644 index b51a880629b..00000000000 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ /dev/null @@ -1,268 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_, DropPath - - -class WMSA(nn.Module): - """ Self-attention module in Swin Transformer - """ - - def __init__(self, input_dim, output_dim, head_dim, window_size, type): - super(WMSA, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.head_dim = head_dim - self.scale = self.head_dim ** -0.5 - self.n_heads = input_dim // head_dim - self.window_size = window_size - self.type = type - self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) - - self.relative_position_params = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) - - self.linear = nn.Linear(self.input_dim, self.output_dim) - - trunc_normal_(self.relative_position_params, std=.02) - self.relative_position_params = torch.nn.Parameter( - self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, - 2).transpose( - 0, 1)) - - def generate_mask(self, h, w, p, shift): - """ generating the mask of SW-MSA - Args: - shift: shift parameters in CyclicShift. - Returns: - attn_mask: should be (1 1 w p p), - """ - # supporting square. - attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) - if self.type == 'W': - return attn_mask - - s = p - shift - attn_mask[-1, :, :s, :, s:, :] = True - attn_mask[-1, :, s:, :, :s, :] = True - attn_mask[:, -1, :, :s, :, s:] = True - attn_mask[:, -1, :, s:, :, :s] = True - attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') - return attn_mask - - def forward(self, x): - """ Forward pass of Window Multi-head Self-attention module. 
- Args: - x: input tensor with shape of [b h w c]; - attn_mask: attention mask, fill -inf where the value is True; - Returns: - output: tensor shape [b h w c] - """ - if self.type != 'W': - x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) - - x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) - h_windows = x.size(1) - w_windows = x.size(2) - # square validation - # assert h_windows == w_windows - - x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) - qkv = self.embedding_layer(x) - q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) - sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale - # Adding learnable relative embedding - sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') - # Using Attn Mask to distinguish different subwindows. - if self.type != 'W': - attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) - sim = sim.masked_fill_(attn_mask, float("-inf")) - - probs = nn.functional.softmax(sim, dim=-1) - output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) - output = rearrange(output, 'h b w p c -> b w p (h c)') - output = self.linear(output) - output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - - if self.type != 'W': - output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) - - return output - - def relative_embedding(self): - cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) - relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 - # negative is allowed - return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] - - -class Block(nn.Module): - def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer Block - """ - super(Block, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - assert type in ['W', 'SW'] - self.type = type - if input_resolution <= window_size: - self.type = 'W' - - self.ln1 = nn.LayerNorm(input_dim) - self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.ln2 = nn.LayerNorm(input_dim) - self.mlp = nn.Sequential( - nn.Linear(input_dim, 4 * input_dim), - nn.GELU(), - nn.Linear(4 * input_dim, output_dim), - ) - - def forward(self, x): - x = x + self.drop_path(self.msa(self.ln1(x))) - x = x + self.drop_path(self.mlp(self.ln2(x))) - return x - - -class ConvTransBlock(nn.Module): - def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer and Conv Block - """ - super(ConvTransBlock, self).__init__() - self.conv_dim = conv_dim - self.trans_dim = trans_dim - self.head_dim = head_dim - self.window_size = window_size - self.drop_path = drop_path - self.type = type - self.input_resolution = input_resolution - - assert self.type in ['W', 'SW'] - if self.input_resolution <= self.window_size: - self.type = 'W' - - self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, - self.type, self.input_resolution) - self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - - self.conv_block = nn.Sequential( - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), - nn.ReLU(True), - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) - ) - - def forward(self, x): - conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) - conv_x = self.conv_block(conv_x) + conv_x - trans_x = Rearrange('b c h w -> b h w c')(trans_x) - trans_x = self.trans_block(trans_x) - trans_x = Rearrange('b h w c -> b c h w')(trans_x) - res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) - x = x + res - - return x - - -class SCUNet(nn.Module): - # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): - def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): - super(SCUNet, self).__init__() - if config is None: - config = [2, 2, 2, 2, 2, 2, 2] - self.config = config - self.dim = dim - self.head_dim = 32 - self.window_size = 8 - - # drop path rate for each layer - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] - - self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] - - begin = 0 - self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[0])] + \ - [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] - - begin += config[0] - self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[1])] + \ - [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] - - begin += config[1] - self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[2])] + \ - [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] - - begin += config[2] - self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 8) - for i in range(config[3])] - - begin += config[3] - self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 
else 'SW', input_resolution // 4) - for i in range(config[4])] - - begin += config[4] - self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[5])] - - begin += config[5] - self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[6])] - - self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] - - self.m_head = nn.Sequential(*self.m_head) - self.m_down1 = nn.Sequential(*self.m_down1) - self.m_down2 = nn.Sequential(*self.m_down2) - self.m_down3 = nn.Sequential(*self.m_down3) - self.m_body = nn.Sequential(*self.m_body) - self.m_up3 = nn.Sequential(*self.m_up3) - self.m_up2 = nn.Sequential(*self.m_up2) - self.m_up1 = nn.Sequential(*self.m_up1) - self.m_tail = nn.Sequential(*self.m_tail) - # self.apply(self._init_weights) - - def forward(self, x0): - - h, w = x0.size()[-2:] - paddingBottom = int(np.ceil(h / 64) * 64 - h) - paddingRight = int(np.ceil(w / 64) * 64 - w) - x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) - - x1 = self.m_head(x0) - x2 = self.m_down1(x1) - x3 = self.m_down2(x2) - x4 = self.m_down3(x3) - x = self.m_body(x4) - x = self.m_up3(x + x4) - x = self.m_up2(x + x3) - x = self.m_up1(x + x2) - x = self.m_tail(x + x1) - - x = x[..., :h, :w] - - return x - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index ae0d0e6a8ea..85c18b9e939 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,5 +1,5 @@ +import logging import sys -import platform import numpy as np import torch @@ -8,13 +8,11 @@ from modules import modelloader, devices, script_callbacks, shared from modules.shared import opts, state -from swinir_model_arch import SwinIR -from swinir_model_arch_v2 import Swin2SR from modules.upscaler import Upscaler, UpscalerData SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" -device_swinir = devices.get_device_for('swinir') +logger = logging.getLogger(__name__) class UpscalerSwinIR(Upscaler): @@ -37,26 +35,29 @@ def __init__(self, dirname): scalers.append(model_data) self.scalers = scalers - def do_upscale(self, img, model_file): - use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \ - and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows" + def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: current_config = (model_file, opts.SWIN_tile) - if use_compile and self._cached_model_config == current_config: + device = self._get_device() + + if self._cached_model_config == current_config: model = self._cached_model else: - self._cached_model = None try: model = self.load_model(model_file) except Exception as e: print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) return img - model = model.to(device_swinir, dtype=devices.dtype) - if use_compile: - model = 
torch.compile(model) - self._cached_model = model - self._cached_model_config = current_config - img = upscale(img, model) + self._cached_model = model + self._cached_model_config = current_config + + img = upscale( + img, + model, + tile=opts.SWIN_tile, + tile_overlap=opts.SWIN_tile_overlap, + device=device, + ) devices.torch_gc() return img @@ -69,69 +70,54 @@ def load_model(self, path, scale=4): ) else: filename = path - if filename.endswith(".v2.pth"): - model = Swin2SR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", - ) - params = None - else: - model = SwinIR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="3conv", - ) - params = "params_ema" - pretrained_model = torch.load(filename) - if params is not None: - model.load_state_dict(pretrained_model[params], strict=True) - else: - model.load_state_dict(pretrained_model, strict=True) + model = modelloader.load_spandrel_model( + filename, + device=self._get_device(), + dtype=devices.dtype, + ) + if getattr(opts, 'SWIN_torch_compile', False): + try: + model = torch.compile(model) + except Exception: + logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) return model + def _get_device(self): + return devices.get_device_for('swinir') + def upscale( - img, - model, - tile=None, - tile_overlap=None, - window_size=8, - scale=4, + img, + model, + *, + tile: int, + tile_overlap: int, + window_size=8, + scale=4, + device, ): - tile = tile or opts.SWIN_tile - tile_overlap = tile_overlap or opts.SWIN_tile_overlap - img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) + img = img.unsqueeze(0).to(device, dtype=devices.dtype) with torch.no_grad(), devices.autocast(): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference(img, model, tile, tile_overlap, window_size, scale) + output = inference( + img, + model, + tile=tile, + tile_overlap=tile_overlap, + window_size=window_size, + scale=scale, + device=device, + ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() if output.ndim == 3: @@ -142,7 +128,16 @@ def upscale( return Image.fromarray(output, "RGB") -def inference(img, model, tile, tile_overlap, window_size, scale): +def inference( + img, + model, + *, + tile: int, + tile_overlap: int, + window_size: int, + scale: int, + device, +): # test the image tile by tile b, c, h, w = img.size() tile = min(tile, h, w) @@ -152,8 +147,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): stride = tile - tile_overlap h_idx_list = list(range(0, h - tile, stride)) + [h - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir) + E = 
torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img) + W = torch.zeros_like(E, dtype=devices.dtype, device=device) with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: @@ -185,8 +180,7 @@ def on_ui_settings(): shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) - if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows - shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) + shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py deleted file mode 100644 index 93b9327473a..00000000000 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ /dev/null @@ -1,867 +0,0 @@ -# ----------------------------------------------------------------------------------- -# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 -# Originally Written by Ze Liu, Modified by Jingyun Liang. 
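# A minimal standalone sketch, not from the patch, of the overlap-weighted
# tiling scheme used by inference() above: each tile is upscaled on its own
# and accumulated into E, while W counts how many tiles covered each output
# pixel, so E / W averages the overlapping seams away. The tile sizes and the
# identity model in the note at the end are illustrative only.
import torch

def tiled_upscale(img, model, tile=48, tile_overlap=8, scale=1):
    b, c, h, w = img.shape
    tile = min(tile, h, w)
    stride = tile - tile_overlap
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
    E = torch.zeros(b, c, h * scale, w * scale)
    W = torch.zeros_like(E)
    for h_idx in h_idx_list:
        for w_idx in w_idx_list:
            out = model(img[..., h_idx:h_idx + tile, w_idx:w_idx + tile])
            E[..., h_idx * scale:(h_idx + tile) * scale, w_idx * scale:(w_idx + tile) * scale] += out
            W[..., h_idx * scale:(h_idx + tile) * scale, w_idx * scale:(w_idx + tile) * scale] += 1
    return E / W

# tiled_upscale(torch.rand(1, 3, 64, 64), lambda t: t) reproduces its input,
# confirming that overlapping regions average back to the original values.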
-# ----------------------------------------------------------------------------------- - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. 
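# A short standalone sketch, with a toy 2x2 window, of how WindowAttention
# above builds its relative_position_index: pairwise coordinate differences
# are shifted to start at zero and folded row-major into a single index into
# the (2*Wh-1)*(2*Ww-1) bias table. (indexing="ij" is spelled out here; the
# original relies on the older torch.meshgrid default.)
import torch

Wh = Ww = 2
coords = torch.stack(torch.meshgrid(torch.arange(Wh), torch.arange(Ww), indexing="ij"))
coords_flatten = torch.flatten(coords, 1)                           # 2, Wh*Ww
relative = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative = relative.permute(1, 2, 0).contiguous()                   # Wh*Ww, Wh*Ww, 2
relative[:, :, 0] += Wh - 1   # shift row offsets to start from 0
relative[:, :, 1] += Ww - 1
relative[:, :, 0] *= 2 * Ww - 1
index = relative.sum(-1)      # Wh*Ww x Wh*Ww, one bias row per relative offset
assert index.max().item() == (2 * Wh - 1) * (2 * Ww - 1) - 1  # == 8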
- num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. 
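# A tiny sketch of the pattern BasicLayer above sets up: even-indexed blocks
# use plain window attention (shift 0) and odd-indexed blocks use shifted
# windows (window_size // 2), so information crosses window borders every
# second block. The values below are illustrative.
window_size, depth = 8, 6
shift_sizes = [0 if i % 2 == 0 else window_size // 2 for i in range(depth)]
assert shift_sizes == [0, 4, 0, 4, 0, 4]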
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - flops = 0 - H, W = self.img_size - if self.norm is not None: - flops += H * W * self.embed_dim - return flops - - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - -class SwinIR(nn.Module): - r""" SwinIR - A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. 
or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(SwinIR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - if self.upscale == 4: - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - if 
self.upscale == 4: - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = SwinIR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py deleted file mode 100644 index dad22cca29e..00000000000 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ /dev/null @@ -1,1017 +0,0 @@ -# ----------------------------------------------------------------------------------- -# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ -# Written by Conde and Choi et al. 
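# A standalone sketch of the 'nearest+conv' reconstruction branch in
# SwinIR.forward above (the branch the real-world SR checkpoints use):
# nearest-neighbour upsampling by 2 followed by a 3x3 conv and LeakyReLU,
# applied twice for a 4x model. Channel and spatial sizes are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

num_feat = 64
conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

x = torch.rand(1, num_feat, 16, 16)
x = lrelu(conv_up1(F.interpolate(x, scale_factor=2, mode='nearest')))  # 32x32
x = lrelu(conv_up2(F.interpolate(x, scale_factor=2, mode='nearest')))  # 64x64
x = lrelu(conv_hr(x))  # ready for the final conv_last projection to RGB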
-# ----------------------------------------------------------------------------------- - -import math -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
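# A standalone sketch of the log-spaced coordinate normalization that the
# __init__ below applies to relative_coords_table before feeding it to the
# small cpb_mlp: offsets are scaled into [-8, 8] and log-compressed, which
# keeps the continuous position bias well-behaved when the inference window
# size differs from the pretraining one. The window size here is illustrative.
import math
import torch

win = 8
offsets = torch.arange(-(win - 1), win, dtype=torch.float32)  # -7 .. 7
offsets = offsets / (win - 1) * 8                             # scale to [-8, 8]
log_offsets = torch.sign(offsets) * torch.log2(offsets.abs() + 1.0) / math.log2(8)
# log_offsets spans roughly [-1.06, 1.06], with finer resolution near zero offset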
- """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=(0, 0)): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.pretrained_window_size = pretrained_window_size - self.num_heads = num_heads - - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) - - # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) - - # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack( - torch.meshgrid([relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 - if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) - else: - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) - relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - torch.abs(relative_coords_table) + 1.0) / np.log2(8) - - self.register_buffer("relative_coords_table", relative_coords_table) - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(dim)) - self.v_bias = nn.Parameter(torch.zeros(dim)) - else: - self.q_bias = None - self.v_bias = None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - # cosine attention - attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01)).to(self.logit_scale.device)).exp() - attn = attn * logit_scale - - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, ' \ - f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - pretrained_window_size (int): Window size in pre-training. - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - pretrained_window_size=to_2tuple(pretrained_window_size)) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - #assert L == H * W, "input feature has wrong size" - - shortcut = x - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - x = shortcut + self.drop_path(self.norm1(x)) - - # FFN - x = x + self.drop_path(self.norm2(self.mlp(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(2 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.reduction(x) - x = self.norm(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - flops += H * W * self.dim // 2 - return flops - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - pretrained_window_size (int): Local window size in pre-training. 
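# A standalone sketch of the "residual post-norm" ordering used by the v2
# blocks above (LayerNorm is applied to the branch output before the residual
# add, where v1 normalizes the branch input) together with the
# _init_respostnorm zero-initialisation defined below, which makes every block
# start out as the identity. The Linear stands in for window attention and
# drop_path is omitted.
import torch
import torch.nn as nn

dim = 4
attn = nn.Linear(dim, dim)
norm1 = nn.LayerNorm(dim)
nn.init.constant_(norm1.weight, 0)  # what _init_respostnorm does
nn.init.constant_(norm1.bias, 0)

x = torch.rand(2, 3, dim)
out = x + norm1(attn(x))       # post-norm residual
assert torch.allclose(out, x)  # zero-initialised norm => identity block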
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - pretrained_window_size=0): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - pretrained_window_size=pretrained_window_size) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - def _init_respostnorm(self): - for blk in self.blocks: - nn.init.constant_(blk.norm1.bias, 0) - nn.init.constant_(blk.norm1.weight, 0) - nn.init.constant_(blk.norm2.bias, 0) - nn.init.constant_(blk.norm2.weight, 0) - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. 
- input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - -class Upsample_hf(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample_hf, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - - -class Swin2SR(nn.Module): - r""" Swin2SR - A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - drop_rate (float): Dropout rate. 
Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(Swin2SR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, 
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - - if self.upsampler == 'pixelshuffle_hf': - self.layers_hf = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers_hf.append(layer) - - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffle_aux': - self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) - self.conv_before_upsample = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_after_aux = nn.Sequential( - nn.Conv2d(3, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffle_hf': - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.upsample_hf = Upsample_hf(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - self.conv_before_upsample_hf = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - assert self.upscale == 4, 'only support x4 now.' 
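
Editor's aside, not part of the patch: the 'pixelshuffle' branches above build on the Upsample module defined earlier, which reaches a 2^n scale factor by repeating a 3x3 conv that widens channels 4x followed by nn.PixelShuffle(2). A minimal self-contained sketch of that mechanism, assuming scale=4 and 64 feature channels:

    import torch
    import torch.nn as nn

    num_feat = 64
    upsample = nn.Sequential(
        nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1),  # widen channels 4x
        nn.PixelShuffle(2),                          # fold channels into a 2x larger grid
        nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1),
        nn.PixelShuffle(2),                          # second 2x step, 4x total
    )

    x = torch.randn(1, num_feat, 16, 16)
    print(upsample(x).shape)  # torch.Size([1, 64, 64, 64])

Sub-pixel convolution lets the preceding conv learn the upsampled content rather than interpolating it; the 'nearest+conv' branch that follows trades this for interpolation plus convs, which the original comments note gives fewer artifacts for real-world SR.
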
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward_features_hf(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers_hf: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffle_aux': - bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) - bicubic = self.conv_bicubic(bicubic) - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - aux = self.conv_aux(x) # b, 3, LR_H, LR_W - x = self.conv_after_aux(aux) - x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] - x = self.conv_last(x) - aux = aux / self.img_range + self.mean - elif self.upsampler == 'pixelshuffle_hf': - # for classical SR with HF - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x_before = self.conv_before_upsample(x) - x_out = self.conv_last(self.upsample(x_before)) - - x_hf = self.conv_first_hf(x_before) - x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf - x_hf = self.conv_before_upsample_hf(x_hf) - x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) - x = x_out + x_hf - x_hf = x_hf / self.img_range + self.mean - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = 
self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - if self.upsampler == "pixelshuffle_aux": - return x[:, :, :H*self.upscale, :W*self.upscale], aux - - elif self.upsampler == "pixelshuffle_hf": - x_out = x_out / self.img_range + self.mean - return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] - - else: - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = Swin2SR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py deleted file mode 100644 index 12db6814268..00000000000 --- a/modules/codeformer/codeformer_arch.py +++ /dev/null @@ -1,276 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -import math -import torch -from torch import nn, Tensor -import torch.nn.functional as F -from typing import Optional - -from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock -from basicsr.utils.registry import ARCH_REGISTRY - -def calc_mean_std(feat, eps=1e-5): - """Calculate mean and std for adaptive_instance_normalization. - - Args: - feat (Tensor): 4D tensor. - eps (float): A small value added to the variance to avoid - divide-by-zero. Default: 1e-5. - """ - size = feat.size() - assert len(size) == 4, 'The input feature should be 4D tensor.' - b, c = size[:2] - feat_var = feat.view(b, c, -1).var(dim=2) + eps - feat_std = feat_var.sqrt().view(b, c, 1, 1) - feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) - return feat_mean, feat_std - - -def adaptive_instance_normalization(content_feat, style_feat): - """Adaptive instance normalization. - - Adjust the reference features to have the similar color and illuminations - as those in the degradate features. - - Args: - content_feat (Tensor): The reference feature. - style_feat (Tensor): The degradate features. 
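
Editor's aside, not part of the patch: adaptive_instance_normalization below whitens the content features with their own per-channel statistics from calc_mean_std, then rescales them with the style statistics. A small self-contained check of that arithmetic (shapes and values are arbitrary):

    import torch

    def calc_mean_std(feat, eps=1e-5):
        b, c = feat.size()[:2]
        var = feat.view(b, c, -1).var(dim=2) + eps
        std = var.sqrt().view(b, c, 1, 1)
        mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
        return mean, std

    content = torch.randn(2, 8, 16, 16) * 3.0 + 1.0
    style = torch.randn(2, 8, 16, 16) * 0.5 - 2.0

    c_mean, c_std = calc_mean_std(content)
    s_mean, s_std = calc_mean_std(style)
    out = (content - c_mean) / c_std * s_std + s_mean

    # the output now carries the style's per-channel mean
    print(torch.allclose(calc_mean_std(out)[0], s_mean, atol=1e-4))  # True
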
- """ - size = content_feat.size() - style_mean, style_std = calc_mean_std(style_feat) - content_mean, content_std = calc_mean_std(content_feat) - normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) - return normalized_feat * style_std.expand(size) + style_mean.expand(size) - - -class PositionEmbeddingSine(nn.Module): - """ - This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. - """ - - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): - super().__init__() - self.num_pos_feats = num_pos_feats - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - def forward(self, x, mask=None): - if mask is None: - mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack( - (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos_y = torch.stack( - (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - -def _get_activation_fn(activation): - """Return an activation function given a string""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(F"activation should be relu/gelu, not {activation}.") - - -class TransformerSALayer(nn.Module): - def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): - super().__init__() - self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) - # Implementation of Feedforward model - MLP - self.linear1 = nn.Linear(embed_dim, dim_mlp) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_mlp, embed_dim) - - self.norm1 = nn.LayerNorm(embed_dim) - self.norm2 = nn.LayerNorm(embed_dim) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward(self, tgt, - tgt_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): - - # self attention - tgt2 = self.norm1(tgt) - q = k = self.with_pos_embed(tgt2, query_pos) - tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask)[0] - tgt = tgt + self.dropout1(tgt2) - - # ffn - tgt2 = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) - tgt = tgt + self.dropout2(tgt2) - return tgt - -class Fuse_sft_block(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.encode_enc = ResBlock(2*in_ch, 
out_ch) - - self.scale = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - self.shift = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - def forward(self, enc_feat, dec_feat, w=1): - enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) - scale = self.scale(enc_feat) - shift = self.shift(enc_feat) - residual = w * (dec_feat * scale + shift) - out = dec_feat + residual - return out - - -@ARCH_REGISTRY.register() -class CodeFormer(VQAutoEncoder): - def __init__(self, dim_embd=512, n_head=8, n_layers=9, - codebook_size=1024, latent_size=256, - connect_list=('32', '64', '128', '256'), - fix_modules=('quantize', 'generator')): - super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) - - if fix_modules is not None: - for module in fix_modules: - for param in getattr(self, module).parameters(): - param.requires_grad = False - - self.connect_list = connect_list - self.n_layers = n_layers - self.dim_embd = dim_embd - self.dim_mlp = dim_embd*2 - - self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) - self.feat_emb = nn.Linear(256, self.dim_embd) - - # transformer - self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) - for _ in range(self.n_layers)]) - - # logits_predict head - self.idx_pred_layer = nn.Sequential( - nn.LayerNorm(dim_embd), - nn.Linear(dim_embd, codebook_size, bias=False)) - - self.channels = { - '16': 512, - '32': 256, - '64': 256, - '128': 128, - '256': 128, - '512': 64, - } - - # after second residual block for > 16, before attn layer for ==16 - self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} - # after first residual block for > 16, before attn layer for ==16 - self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} - - # fuse_convs_dict - self.fuse_convs_dict = nn.ModuleDict() - for f_size in self.connect_list: - in_ch = self.channels[f_size] - self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): - # ################### Encoder ##################### - enc_feat_dict = {} - out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] - for i, block in enumerate(self.encoder.blocks): - x = block(x) - if i in out_list: - enc_feat_dict[str(x.shape[-1])] = x.clone() - - lq_feat = x - # ################# Transformer ################### - # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) - pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) - # BCHW -> BC(HW) -> (HW)BC - feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) - query_emb = feat_emb - # Transformer encoder - for layer in self.ft_layers: - query_emb = layer(query_emb, query_pos=pos_emb) - - # output logits - logits = self.idx_pred_layer(query_emb) # (hw)bn - logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n - - if code_only: # for training stage II - # logits 
doesn't need softmax before cross_entropy loss - return logits, lq_feat - - # ################# Quantization ################### - # if self.training: - # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) - # # b(hw)c -> bc(hw) -> bchw - # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) - # ------------ - soft_one_hot = F.softmax(logits, dim=2) - _, top_idx = torch.topk(soft_one_hot, 1, dim=2) - quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) - # preserve gradients - # quant_feat = lq_feat + (quant_feat - lq_feat).detach() - - if detach_16: - quant_feat = quant_feat.detach() # for training stage III - if adain: - quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) - - # ################## Generator #################### - x = quant_feat - fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] - - for i, block in enumerate(self.generator.blocks): - x = block(x) - if i in fuse_list: # fuse after i-th block - f_size = str(x.shape[-1]) - if w>0: - x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) - out = x - # logits doesn't need softmax before cross_entropy loss - return out, logits, lq_feat diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py deleted file mode 100644 index 09ee6660dc5..00000000000 --- a/modules/codeformer/vqgan_arch.py +++ /dev/null @@ -1,435 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -''' -VQGAN code, adapted from the original created by the Unleashing Transformers authors: -https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py - -''' -import torch -import torch.nn as nn -import torch.nn.functional as F -from basicsr.utils import get_root_logger -from basicsr.utils.registry import ARCH_REGISTRY - -def normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -@torch.jit.script -def swish(x): - return x*torch.sigmoid(x) - - -# Define VQVAE classes -class VectorQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, beta): - super(VectorQuantizer, self).__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 - self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) - self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.emb_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) - - mean_distance = torch.mean(d) - # find closest encodings - # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) - min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) - # [0-1], higher score, higher confidence - min_encoding_scores = torch.exp(-min_encoding_scores/10) - - min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) - min_encodings.scatter_(1, min_encoding_indices, 1) - - # get quantized latent vectors - z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) - # 
compute loss for embedding - loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - # preserve gradients - z_q = z + (z_q - z).detach() - - # perplexity - e_mean = torch.mean(min_encodings, dim=0) - perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q, loss, { - "perplexity": perplexity, - "min_encodings": min_encodings, - "min_encoding_indices": min_encoding_indices, - "min_encoding_scores": min_encoding_scores, - "mean_distance": mean_distance - } - - def get_codebook_feat(self, indices, shape): - # input indices: batch*token_num -> (batch*token_num)*1 - # shape: batch, height, width, channel - indices = indices.view(-1,1) - min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) - min_encodings.scatter_(1, indices, 1) - # get quantized latent vectors - z_q = torch.matmul(min_encodings.float(), self.embedding.weight) - - if shape is not None: # reshape back to match original input shape - z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() - - return z_q - - -class GumbelQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): - super().__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.straight_through = straight_through - self.temperature = temp_init - self.kl_weight = kl_weight - self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits - self.embed = nn.Embedding(codebook_size, emb_dim) - - def forward(self, z): - hard = self.straight_through if self.training else True - - logits = self.proj(z) - - soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) - - z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) - - # + kl divergence to the prior loss - qy = F.softmax(logits, dim=1) - diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() - min_encoding_indices = soft_one_hot.argmax(dim=1) - - return z_q, diff, { - "min_encoding_indices": min_encoding_indices - } - - -class Downsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - return x - - -class Upsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - x = self.conv(x) - - return x - - -class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels=None): - super(ResBlock, self).__init__() - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - self.norm1 = normalize(in_channels) - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.norm2 = normalize(out_channels) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x_in): - x = x_in - x = 
self.norm1(x) - x = swish(x) - x = self.conv1(x) - x = self.norm2(x) - x = swish(x) - x = self.conv2(x) - if self.in_channels != self.out_channels: - x_in = self.conv_out(x_in) - - return x + x_in - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, h*w) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c)**(-0.5)) - w_ = F.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) - h_ = torch.bmm(v, w_) - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x+h_ - - -class Encoder(nn.Module): - def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): - super().__init__() - self.nf = nf - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.attn_resolutions = attn_resolutions - - curr_res = self.resolution - in_ch_mult = (1,)+tuple(ch_mult) - - blocks = [] - # initial convultion - blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) - - # residual and downsampling blocks, with attention on smaller res (16x16) - for i in range(self.num_resolutions): - block_in_ch = nf * in_ch_mult[i] - block_out_ch = nf * ch_mult[i] - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - if curr_res in attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != self.num_resolutions - 1: - blocks.append(Downsample(block_in_ch)) - curr_res = curr_res // 2 - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - # normalise and convert to latent size - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) - self.blocks = nn.ModuleList(blocks) - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -class Generator(nn.Module): - def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): - super().__init__() - self.nf = nf - self.ch_mult = ch_mult - self.num_resolutions = len(self.ch_mult) - self.num_res_blocks = res_blocks - self.resolution = img_size - self.attn_resolutions = attn_resolutions - self.in_channels = emb_dim - self.out_channels = 3 - block_in_ch = self.nf * self.ch_mult[-1] - curr_res = self.resolution // 2 ** (self.num_resolutions-1) - - blocks = [] - # initial conv - blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - for i in 
reversed(range(self.num_resolutions)): - block_out_ch = self.nf * self.ch_mult[i] - - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - - if curr_res in self.attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != 0: - blocks.append(Upsample(block_in_ch)) - curr_res = curr_res * 2 - - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) - - self.blocks = nn.ModuleList(blocks) - - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -@ARCH_REGISTRY.register() -class VQAutoEncoder(nn.Module): - def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, - beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): - super().__init__() - logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks - self.codebook_size = codebook_size - self.embed_dim = emb_dim - self.ch_mult = ch_mult - self.resolution = img_size - self.attn_resolutions = attn_resolutions or [16] - self.quantizer_type = quantizer - self.encoder = Encoder( - self.in_channels, - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - if self.quantizer_type == "nearest": - self.beta = beta #0.25 - self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) - elif self.quantizer_type == "gumbel": - self.gumbel_num_hiddens = emb_dim - self.straight_through = gumbel_straight_through - self.kl_weight = gumbel_kl_weight - self.quantize = GumbelQuantizer( - self.codebook_size, - self.embed_dim, - self.gumbel_num_hiddens, - self.straight_through, - self.kl_weight - ) - self.generator = Generator( - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_ema' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) - logger.info(f'vqgan is loaded from: {model_path} [params_ema]') - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - logger.info(f'vqgan is loaded from: {model_path} [params]') - else: - raise ValueError('Wrong params!') - - - def forward(self, x): - x = self.encoder(x) - quant, codebook_loss, quant_stats = self.quantize(x) - x = self.generator(quant) - return x, codebook_loss, quant_stats - - - -# patch based discriminator -@ARCH_REGISTRY.register() -class VQGANDiscriminator(nn.Module): - def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): - super().__init__() - - layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] - ndf_mult = 1 - ndf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n, 8) - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n_layers, 8) - - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - layers += [ - nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, 
stride=1, padding=1)] # output 1 channel prediction map - self.main = nn.Sequential(*layers) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_d' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - else: - raise ValueError('Wrong params!') - - def forward(self, x): - return self.main(x) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index da42b5e9932..517eadfd8de 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -8,9 +8,6 @@ from modules import shared, devices, modelloader, errors from modules.paths import models_path -# codeformer people made a choice to include modified basicsr library to their project which makes -# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. -# I am making a choice to include some files from codeformer to work around this issue. model_dir = "Codeformer" model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' @@ -18,115 +15,127 @@ codeformer = None -def setup_model(dirname): - os.makedirs(model_path, exist_ok=True) - - path = modules.paths.paths.get("CodeFormer", None) - if path is None: - return - - try: +class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): + def name(self): + return "CodeFormer" + + def __init__(self, dirname): + self.net = None + self.face_helper = None + self.cmd_dir = dirname + + def create_models(self): + from facexlib.detection import retinaface + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + + if self.net is not None and self.face_helper is not None: + self.net.to(devices.device_codeformer) + return self.net, self.face_helper + model_paths = modelloader.load_models( + model_path, + model_url, + self.cmd_dir, + download_name='codeformer-v0.1.0.pth', + ext_filter=['.pth'], + ) + + if len(model_paths) != 0: + ckpt_path = model_paths[0] + else: + print("Unable to load codeformer model.") + return None, None + net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer) + + if hasattr(retinaface, 'device'): + retinaface.device = devices.device_codeformer + + face_helper = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=devices.device_codeformer, + ) + + self.net = net + self.face_helper = face_helper + + def send_model_to(self, device): + self.net.to(device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + + def restore(self, np_image, w=None): from torchvision.transforms.functional import normalize - from modules.codeformer.codeformer_arch import CodeFormer from basicsr.utils import img2tensor, tensor2img - from facelib.utils.face_restoration_helper import FaceRestoreHelper - from facelib.detection.retinaface import retinaface - - net_class = CodeFormer - - class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): - def name(self): - return "CodeFormer" - - def __init__(self, dirname): - self.net = None - self.face_helper = None - self.cmd_dir = dirname - - def create_models(self): - - if self.net is not None and self.face_helper is not None: - self.net.to(devices.device_codeformer) - return self.net, self.face_helper - model_paths = 
modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth']) - if len(model_paths) != 0: - ckpt_path = model_paths[0] - else: - print("Unable to load codeformer model.") - return None, None - net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) - checkpoint = torch.load(ckpt_path)['params_ema'] - net.load_state_dict(checkpoint) - net.eval() + np_image = np_image[:, :, ::-1] - if hasattr(retinaface, 'device'): - retinaface.device = devices.device_codeformer - face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) + original_resolution = np_image.shape[0:2] - self.net = net - self.face_helper = face_helper + self.create_models() + if self.net is None or self.face_helper is None: + return np_image - return net, face_helper + self.send_model_to(devices.device_codeformer) - def send_model_to(self, device): - self.net.to(device) - self.face_helper.face_det.to(device) - self.face_helper.face_parse.to(device) + self.face_helper.clean_all() + self.face_helper.read_image(np_image) + self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + self.face_helper.align_warp_face() - def restore(self, np_image, w=None): - np_image = np_image[:, :, ::-1] + for cropped_face in self.face_helper.cropped_faces: + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) - original_resolution = np_image.shape[0:2] + try: + with torch.no_grad(): + res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True) + if isinstance(res, tuple): + output = res[0] + else: + output = res + if not isinstance(output, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(output)}") + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + del output + devices.torch_gc() + except Exception: + errors.report('Failed inference for CodeFormer', exc_info=True) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - self.create_models() - if self.net is None or self.face_helper is None: - return np_image + restored_face = restored_face.astype('uint8') + self.face_helper.add_restored_face(restored_face) - self.send_model_to(devices.device_codeformer) + self.face_helper.get_inverse_affine(None) - self.face_helper.clean_all() - self.face_helper.read_image(np_image) - self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) - self.face_helper.align_warp_face() + restored_img = self.face_helper.paste_faces_to_input_image() + restored_img = restored_img[:, :, ::-1] - for cropped_face in self.face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) + if original_resolution != restored_img.shape[0:2]: + restored_img = cv2.resize( + restored_img, + (0, 0), + fx=original_resolution[1]/restored_img.shape[1], + fy=original_resolution[0]/restored_img.shape[0], + interpolation=cv2.INTER_LINEAR, + ) - try: - with torch.no_grad(): - output = self.net(cropped_face_t, 
w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - devices.torch_gc() - except Exception: - errors.report('Failed inference for CodeFormer', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + self.face_helper.clean_all() - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) + if shared.opts.face_restoration_unload: + self.send_model_to(devices.cpu) - self.face_helper.get_inverse_affine(None) + return restored_img - restored_img = self.face_helper.paste_faces_to_input_image() - restored_img = restored_img[:, :, ::-1] - - if original_resolution != restored_img.shape[0:2]: - restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) - - self.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - self.send_model_to(devices.cpu) - - return restored_img +def setup_model(dirname): + os.makedirs(model_path, exist_ok=True) + try: global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) - except Exception: errors.report("Error setting up CodeFormer", exc_info=True) - - # sys.path = stored_sys_path diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c0d22a99273..a7c7c9e3013 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,122 +1,9 @@ -import sys - -import torch - -import modules.esrgan_model_arch as arch -from modules import modelloader, devices +from modules import modelloader, devices, errors from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model -def mod2normal(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - if 'conv_first.weight' in state_dict: - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] - crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] - crt_net['model.3.weight'] = state_dict['upconv1.weight'] - crt_net['model.3.bias'] = state_dict['upconv1.bias'] - crt_net['model.6.weight'] = state_dict['upconv2.weight'] - crt_net['model.6.bias'] = state_dict['upconv2.bias'] - crt_net['model.8.weight'] = state_dict['HRconv.weight'] - crt_net['model.8.bias'] = state_dict['HRconv.bias'] - crt_net['model.10.weight'] = state_dict['conv_last.weight'] - crt_net['model.10.bias'] = state_dict['conv_last.bias'] - state_dict = crt_net - return state_dict - - -def resrgan2normal(state_dict, nb=23): - # this code is copied from https://github.com/victorca25/iNNfer - if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: - re8x = 0 - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if "rdb" in k: - ori_k = k.replace('body.', 'model.1.sub.') - ori_k = 
ori_k.replace('.rdb', '.RDB') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] - crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] - crt_net['model.3.weight'] = state_dict['conv_up1.weight'] - crt_net['model.3.bias'] = state_dict['conv_up1.bias'] - crt_net['model.6.weight'] = state_dict['conv_up2.weight'] - crt_net['model.6.bias'] = state_dict['conv_up2.bias'] - - if 'conv_up3.weight' in state_dict: - # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py - re8x = 3 - crt_net['model.9.weight'] = state_dict['conv_up3.weight'] - crt_net['model.9.bias'] = state_dict['conv_up3.bias'] - - crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight'] - crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias'] - crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight'] - crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias'] - - state_dict = crt_net - return state_dict - - -def infer_params(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - scale2x = 0 - scalemin = 6 - n_uplayer = 0 - plus = False - - for block in list(state_dict): - parts = block.split(".") - n_parts = len(parts) - if n_parts == 5 and parts[2] == "sub": - nb = int(parts[3]) - elif n_parts == 3: - part_num = int(parts[1]) - if (part_num > scalemin - and parts[0] == "model" - and parts[2] == "weight"): - scale2x += 1 - if part_num > n_uplayer: - n_uplayer = part_num - out_nc = state_dict[block].shape[0] - if not plus and "conv1x1" in block: - plus = True - - nf = state_dict["model.0.weight"].shape[0] - in_nc = state_dict["model.0.weight"].shape[1] - out_nc = out_nc - scale = 2 ** scale2x - - return in_nc, out_nc, nf, nb, plus, scale - - class UpscalerESRGAN(Upscaler): def __init__(self, dirname): self.name = "ESRGAN" @@ -142,12 +29,11 @@ def __init__(self, dirname): def do_upscale(self, img, selected_model): try: model = self.load_model(selected_model) - except Exception as e: - print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) + except Exception: + errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True) return img model.to(devices.device_esrgan) - img = esrgan_upscale(model, img) - return img + return esrgan_upscale(model, img) def load_model(self, path: str): if path.startswith("http"): @@ -160,33 +46,10 @@ def load_model(self, path: str): else: filename = path - state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - - if "params_ema" in state_dict: - state_dict = state_dict["params_ema"] - elif "params" in state_dict: - state_dict = state_dict["params"] - num_conv = 16 if "realesr-animevideov3" in filename else 32 - model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') - model.load_state_dict(state_dict) - model.eval() - return model - - if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: - nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 - state_dict = resrgan2normal(state_dict, nb) - elif "conv_first.weight" in state_dict: - state_dict = mod2normal(state_dict) - elif "model.0.weight" not in state_dict: - raise Exception("The file is not a recognized ESRGAN model.") - - in_nc, out_nc, nf, nb, plus, 
mscale = infer_params(state_dict) - - model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) - model.load_state_dict(state_dict) - model.eval() - - return model + return modelloader.load_spandrel_model( + filename, + device=('cpu' if devices.device_esrgan.type == 'mps' else None), + ) def esrgan_upscale(model, img): diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py deleted file mode 100644 index 2b9888bafbc..00000000000 --- a/modules/esrgan_model_arch.py +++ /dev/null @@ -1,465 +0,0 @@ -# this file is adapted from https://github.com/victorca25/iNNfer - -from collections import OrderedDict -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - - -#################### -# RRDBNet Generator -#################### - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, - act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', - finalact=None, gaussian_noise=False, plus=False): - super(RRDBNet, self).__init__() - n_upscale = int(math.log(upscale, 2)) - if upscale == 3: - n_upscale = 1 - - self.resrgan_scale = 0 - if in_nc % 16 == 0: - self.resrgan_scale = 1 - elif in_nc != 4 and in_nc % 4 == 0: - self.resrgan_scale = 2 - - fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)] - LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype) - - if upsample_mode == 'upconv': - upsample_block = upconv_block - elif upsample_mode == 'pixelshuffle': - upsample_block = pixelshuffle_block - else: - raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found') - if upscale == 3: - upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) - else: - upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)] - HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype) - HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - - outact = act(finalact) if finalact else None - - self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)), - *upsampler, HR_conv0, HR_conv1, outact) - - def forward(self, x, outm=None): - if self.resrgan_scale == 1: - feat = pixel_unshuffle(x, scale=4) - elif self.resrgan_scale == 2: - feat = pixel_unshuffle(x, scale=2) - else: - feat = x - - return self.model(feat) - - -class RRDB(nn.Module): - """ - Residual in Residual Dense Block - (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) - """ - - def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(RRDB, self).__init__() - # This is for backwards compatibility with existing models - if nr == 3: - self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, 
spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - else: - RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)] - self.RDBs = nn.Sequential(*RDB_list) - - def forward(self, x): - if hasattr(self, 'RDB1'): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - else: - out = self.RDBs(x) - return out * 0.2 + x - - -class ResidualDenseBlock_5C(nn.Module): - """ - Residual Dense Block - The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) - Modified options that can be used: - - "Partial Convolution based Padding" arXiv:1811.11718 - - "Spectral normalization" arXiv:1802.05957 - - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. - {Rakotonirina} and A. {Rasoanaivo} - """ - - def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(ResidualDenseBlock_5C, self).__init__() - - self.noise = GaussianNoise() if gaussian_noise else None - self.conv1x1 = conv1x1(nf, gc) if plus else None - - self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - if mode == 'CNA': - last_act = None - else: - last_act = act_type - self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - - def forward(self, x): - x1 = self.conv1(x) - x2 = self.conv2(torch.cat((x, x1), 1)) - if self.conv1x1: - x2 = x2 + self.conv1x1(x) - x3 = self.conv3(torch.cat((x, x1, x2), 1)) - x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) - if self.conv1x1: - x4 = x4 + x2 - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - if self.noise: - return self.noise(x5.mul(0.2) + x) - else: - return x5 * 0.2 + x - - -#################### -# ESRGANplus -#################### - -class GaussianNoise(nn.Module): - def __init__(self, sigma=0.1, is_relative_detach=False): - super().__init__() - self.sigma = sigma - self.is_relative_detach = is_relative_detach - self.noise = torch.tensor(0, dtype=torch.float) - - def forward(self, x): - if self.training and self.sigma != 0: - self.noise = self.noise.to(x.device) - scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x - sampled_noise = self.noise.repeat(*x.size()).normal_() * scale - x = x + sampled_noise - return x - -def conv1x1(in_planes, out_planes, stride=1): - return 
nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -#################### -# SRVGGNetCompact -#################### - -class SRVGGNetCompact(nn.Module): - """A compact VGG-style network structure for super-resolution. - This class is copied from https://github.com/xinntao/Real-ESRGAN - """ - - def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): - super(SRVGGNetCompact, self).__init__() - self.num_in_ch = num_in_ch - self.num_out_ch = num_out_ch - self.num_feat = num_feat - self.num_conv = num_conv - self.upscale = upscale - self.act_type = act_type - - self.body = nn.ModuleList() - # the first conv - self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) - # the first activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the body structure - for _ in range(num_conv): - self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) - # activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the last conv - self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) - # upsample - self.upsampler = nn.PixelShuffle(upscale) - - def forward(self, x): - out = x - for i in range(0, len(self.body)): - out = self.body[i](out) - - out = self.upsampler(out) - # add the nearest upsampled image, so that the network learns the residual - base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') - out += base - return out - - -#################### -# Upsampler -#################### - -class Upsample(nn.Module): - r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. - The input data is assumed to be of the form - `minibatch x channels x [optional depth] x [optional height] x width`. - """ - - def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): - super(Upsample, self).__init__() - if isinstance(scale_factor, tuple): - self.scale_factor = tuple(float(factor) for factor in scale_factor) - else: - self.scale_factor = float(scale_factor) if scale_factor else None - self.mode = mode - self.size = size - self.align_corners = align_corners - - def forward(self, x): - return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) - - def extra_repr(self): - if self.scale_factor is not None: - info = f'scale_factor={self.scale_factor}' - else: - info = f'size={self.size}' - info += f', mode={self.mode}' - return info - - -def pixel_unshuffle(x, scale): - """ Pixel unshuffle. - Args: - x (Tensor): Input feature with shape (b, c, hh, hw). - scale (int): Downsample ratio. - Returns: - Tensor: the pixel unshuffled feature. 
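
Editor's aside, not part of the patch: pixel_unshuffle below is the inverse of nn.PixelShuffle, folding each scale x scale spatial block into the channel dimension. Recent PyTorch versions ship the same operation as F.pixel_unshuffle, so the round trip is easy to verify:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    packed = F.pixel_unshuffle(x, 2)       # (1, 12, 4, 4): each 2x2 block moves into channels
    restored = F.pixel_shuffle(packed, 2)  # back to (1, 3, 8, 8)
    print(torch.equal(x, restored))        # True

RRDBNet.forward above applies this when resrgan_scale is set, packing the input spatially first so that RealESRGAN checkpoints trained for 2x or 1x output can reuse the same 4x trunk.
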
- """ - b, c, hh, hw = x.size() - out_channel = c * (scale**2) - assert hh % scale == 0 and hw % scale == 0 - h = hh // scale - w = hw // scale - x_view = x.view(b, c, h, scale, w, scale) - return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) - - -def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'): - """ - Pixel shuffle layer - (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional - Neural Network, CVPR17) - """ - conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype) - pixel_shuffle = nn.PixelShuffle(upscale_factor) - - n = norm(norm_type, out_nc) if norm_type else None - a = act(act_type) if act_type else None - return sequential(conv, pixel_shuffle, n, a) - - -def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'): - """ Upconv layer """ - upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor - upsample = Upsample(scale_factor=upscale_factor, mode=mode) - conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype) - return sequential(upsample, conv) - - - - - - - - -#################### -# Basic blocks -#################### - - -def make_layer(basic_block, num_basic_block, **kwarg): - """Make layers by stacking the same blocks. - Args: - basic_block (nn.module): nn.module class for basic block. (block) - num_basic_block (int): number of blocks. (n_layers) - Returns: - nn.Sequential: Stacked blocks in nn.Sequential. 
- """ - layers = [] - for _ in range(num_basic_block): - layers.append(basic_block(**kwarg)) - return nn.Sequential(*layers) - - -def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): - """ activation helper """ - act_type = act_type.lower() - if act_type == 'relu': - layer = nn.ReLU(inplace) - elif act_type in ('leakyrelu', 'lrelu'): - layer = nn.LeakyReLU(neg_slope, inplace) - elif act_type == 'prelu': - layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) - elif act_type == 'tanh': # [-1, 1] range output - layer = nn.Tanh() - elif act_type == 'sigmoid': # [0, 1] range output - layer = nn.Sigmoid() - else: - raise NotImplementedError(f'activation layer [{act_type}] is not found') - return layer - - -class Identity(nn.Module): - def __init__(self, *kwargs): - super(Identity, self).__init__() - - def forward(self, x, *kwargs): - return x - - -def norm(norm_type, nc): - """ Return a normalization layer """ - norm_type = norm_type.lower() - if norm_type == 'batch': - layer = nn.BatchNorm2d(nc, affine=True) - elif norm_type == 'instance': - layer = nn.InstanceNorm2d(nc, affine=False) - elif norm_type == 'none': - def norm_layer(x): return Identity() - else: - raise NotImplementedError(f'normalization layer [{norm_type}] is not found') - return layer - - -def pad(pad_type, padding): - """ padding layer helper """ - pad_type = pad_type.lower() - if padding == 0: - return None - if pad_type == 'reflect': - layer = nn.ReflectionPad2d(padding) - elif pad_type == 'replicate': - layer = nn.ReplicationPad2d(padding) - elif pad_type == 'zero': - layer = nn.ZeroPad2d(padding) - else: - raise NotImplementedError(f'padding layer [{pad_type}] is not implemented') - return layer - - -def get_valid_padding(kernel_size, dilation): - kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - padding = (kernel_size - 1) // 2 - return padding - - -class ShortcutBlock(nn.Module): - """ Elementwise sum the output of a submodule to its input """ - def __init__(self, submodule): - super(ShortcutBlock, self).__init__() - self.sub = submodule - - def forward(self, x): - output = x + self.sub(x) - return output - - def __repr__(self): - return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|') - - -def sequential(*args): - """ Flatten Sequential. It unwraps nn.Sequential. """ - if len(args) == 1: - if isinstance(args[0], OrderedDict): - raise NotImplementedError('sequential does not support OrderedDict input.') - return args[0] # No sequential is needed. 
- modules = [] - for module in args: - if isinstance(module, nn.Sequential): - for submodule in module.children(): - modules.append(submodule) - elif isinstance(module, nn.Module): - modules.append(module) - return nn.Sequential(*modules) - - -def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', - spectral_norm=False): - """ Conv layer with padding, normalization, activation """ - assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]' - padding = get_valid_padding(kernel_size, dilation) - p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None - padding = padding if pad_type == 'zero' else 0 - - if convtype=='PartialConv2D': - from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer - c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='DeformConv2D': - from torchvision.ops import DeformConv2d # not tested - c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='Conv3D': - c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - else: - c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - - if spectral_norm: - c = nn.utils.spectral_norm(c) - - a = act(act_type) if act_type else None - if 'CNA' in mode: - n = norm(norm_type, out_nc) if norm_type else None - return sequential(p, c, n, a) - elif mode == 'NAC': - if norm_type is None and act_type is not None: - a = act(act_type, inplace=False) - n = norm(norm_type, in_nc) if norm_type else None - return sequential(n, a, p, c) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 01d668ecdaf..6b6f17c435a 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,8 +1,5 @@ import os -import facexlib -import gfpgan - import modules.face_restoration from modules import paths, shared, devices, modelloader, errors @@ -41,6 +38,8 @@ def gfpgann(): print("Unable to load gfpgan model!") return None + import facexlib.detection.retinaface + if hasattr(facexlib.detection.retinaface, 'device'): facexlib.detection.retinaface.device = devices.device_gfpgan model_file_path = model_file @@ -81,8 +80,10 @@ def gfpgan_fix_faces(np_image): def setup_model(dirname): try: os.makedirs(model_path, exist_ok=True) - from gfpgan import GFPGANer - from facexlib import detection, parsing # noqa: F401 + import gfpgan + import facexlib.detection + import facexlib.parsing + global user_path global have_gfpgan global gfpgan_constructor @@ -111,7 +112,7 @@ def facex_load_file_from_url2(**kwargs): facexlib.parsing.load_file_from_url = facex_load_file_from_url2 user_path = dirname have_gfpgan = True - gfpgan_constructor = GFPGANer + gfpgan_constructor = gfpgan.GFPGANer class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): def name(self): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index dabef0f5335..c2cbd8ce73e 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -345,13 +345,11 @@ def prepare_environment(): stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', 
"https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') - codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") - codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") try: @@ -408,15 +406,10 @@ def prepare_environment(): git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) - git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) startup_timer.record("clone repositores") - if not is_installed("lpips"): - run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer") - startup_timer.record("install CodeFormer requirements") - if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) diff --git a/modules/modelloader.py b/modules/modelloader.py index 098bcb79336..30116932a1d 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os import shutil import importlib @@ -10,6 +11,9 @@ from modules.paths import script_path, models_path +logger = logging.getLogger(__name__) + + def load_file_from_url( url: str, *, @@ -177,3 +181,15 @@ def load_upscalers(): # Special case for UpscalerNone keeps it at the beginning of the list. 
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) + + +def load_spandrel_model(path, *, device, half: bool = False, dtype=None): + import spandrel + model = spandrel.ModelLoader(device=device).load_from_file(path) + if half: + model = model.model.half() + if dtype: + model = model.model.to(dtype=dtype) + model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) + return model diff --git a/modules/paths.py b/modules/paths.py index 187b949612a..030646519c3 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,7 +38,6 @@ class Dummy: path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), - (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), ] diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 02841c30289..332d8f4b108 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,9 +1,6 @@ import os -import numpy as np -from PIL import Image -from realesrgan import RealESRGANer - +from modules.upscaler_utils import upscale_with_model from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts from modules import modelloader, errors @@ -14,29 +11,20 @@ def __init__(self, path): self.name = "RealESRGAN" self.user_path = path super().__init__() - try: - from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401 - from realesrgan import RealESRGANer # noqa: F401 - from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401 - self.enable = True - self.scalers = [] - scalers = self.load_models(path) + self.enable = True + self.scalers = [] + scalers = get_realesrgan_models(self) - local_model_paths = self.find_models(ext_filter=[".pth"]) - for scaler in scalers: - if scaler.local_data_path.startswith("http"): - filename = modelloader.friendly_name(scaler.local_data_path) - local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] - if local_model_candidates: - scaler.local_data_path = local_model_candidates[0] + local_model_paths = self.find_models(ext_filter=[".pth"]) + for scaler in scalers: + if scaler.local_data_path.startswith("http"): + filename = modelloader.friendly_name(scaler.local_data_path) + local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] + if local_model_candidates: + scaler.local_data_path = local_model_candidates[0] - if scaler.name in opts.realesrgan_enabled_models: - self.scalers.append(scaler) - - except Exception: - errors.report("Error importing Real-ESRGAN", exc_info=True) - self.enable = False - self.scalers = [] + if scaler.name in opts.realesrgan_enabled_models: + self.scalers.append(scaler) def do_upscale(self, img, path): if not self.enable: @@ -48,20 +36,18 @@ def do_upscale(self, img, path): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - upsampler = RealESRGANer( - scale=info.scale, - model_path=info.local_data_path, - model=info.model(), - half=not cmd_opts.no_half and not cmd_opts.upcast_sampling, - tile=opts.ESRGAN_tile, - tile_pad=opts.ESRGAN_tile_overlap, + mod = modelloader.load_spandrel_model( + 
info.local_data_path, device=self.device, + half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + ) + return upscale_with_model( + mod, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + # TODO: `outscale`? ) - - upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] - - image = Image.fromarray(upsampled) - return image def load_model(self, path): for scaler in self.scalers: @@ -76,58 +62,43 @@ def load_model(self, path): return scaler raise ValueError(f"Unable to find model info: {path}") - def load_models(self, _): - return get_realesrgan_models(self) - -def get_realesrgan_models(scaler): - try: - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan.archs.srvgg_arch import SRVGGNetCompact - models = [ - UpscalerData( - name="R-ESRGAN General 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN General WDN 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN AnimeVideo", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN 4x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 4x+ Anime6B", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 2x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", - scale=2, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) - ), - ] - return models - except Exception: - errors.report("Error making Real-ESRGAN models list", exc_info=True) +def get_realesrgan_models(scaler: UpscalerRealESRGAN): + return [ + UpscalerData( + name="R-ESRGAN General 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN General WDN 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN AnimeVideo", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+ Anime6B", + 
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 2x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + scale=2, + upscaler=scaler, + ), + ] diff --git a/modules/sysinfo.py b/modules/sysinfo.py index b669edd0cfd..5abf616b707 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -26,11 +26,9 @@ "OPENCLIP_PACKAGE", "STABLE_DIFFUSION_REPO", "K_DIFFUSION_REPO", - "CODEFORMER_REPO", "BLIP_REPO", "STABLE_DIFFUSION_COMMIT_HASH", "K_DIFFUSION_COMMIT_HASH", - "CODEFORMER_COMMIT_HASH", "BLIP_COMMIT_HASH", "COMMANDLINE_ARGS", "IGNORE_CMD_ARGS_ERRORS", diff --git a/modules/upscaler.py b/modules/upscaler.py index b256e085b6d..3aee69db8d2 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -98,6 +98,9 @@ def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = self.scale = scale self.model = model + def __repr__(self): + return f"" + class UpscalerNone(Upscaler): name = "None" diff --git a/requirements.txt b/requirements.txt index 80b438455ce..36f5674ad99 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ basicsr blendmodes clean-fid einops +facexlib fastapi>=0.90.1 gfpgan gradio==3.41.2 @@ -20,13 +21,11 @@ open-clip-torch piexif psutil pytorch_lightning -realesrgan requests resize-right safetensors scikit-image>=0.19 -timm tomesd torch torchdiffeq diff --git a/requirements_versions.txt b/requirements_versions.txt index cb7403a9d4b..042fa708c49 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -5,6 +5,7 @@ basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1 +facexlib==0.3.0 fastapi==0.94.0 gfpgan==1.3.8 gradio==3.41.2 @@ -19,11 +20,10 @@ open-clip-torch==2.20.0 piexif==1.1.3 psutil==5.9.5 pytorch_lightning==1.9.4 -realesrgan==0.3.0 resize-right==0.0.2 safetensors==0.3.1 scikit-image==0.21.0 -timm==0.9.2 +spandrel==0.1.6 tomesd==0.1.3 torch torchdiffeq==0.2.3 From b621a63cf68c788487684250856707cb352b82d0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 23:01:02 +0200 Subject: [PATCH 0483/1174] Unify CodeFormer and GFPGAN restoration backends, use Spandrel for GFPGAN --- .github/workflows/run_tests.yaml | 8 ++ .gitignore | 1 + modules/codeformer_model.py | 158 +++++++--------------------- modules/face_restoration_utils.py | 163 +++++++++++++++++++++++++++++ modules/gfpgan_model.py | 166 ++++++++++-------------------- requirements.txt | 1 - requirements_versions.txt | 1 - test/conftest.py | 15 ++- test/test_face_restorers.py | 29 ++++++ test/test_files/two-faces.jpg | Bin 0 -> 14768 bytes test/test_outputs/.gitkeep | 0 11 files changed, 308 insertions(+), 234 deletions(-) create mode 100644 modules/face_restoration_utils.py create mode 100644 test/test_face_restorers.py create mode 100644 test/test_files/two-faces.jpg create mode 100644 test/test_outputs/.gitkeep diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 3dafaf8dcfc..cd5c3f868ff 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -20,6 +20,12 @@ jobs: cache-dependency-path: | **/requirements*txt launch.py + - name: Cache models + id: cache-models + uses: actions/cache@v3 + with: + path: models + key: "2023-12-30" - name: Install test dependencies run: pip install wait-for-it -r requirements-test.txt env: @@ -33,6 +39,8 @@ jobs: TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu 
WEBUI_LAUNCH_LIVE_OUTPUT: "1" PYTHONUNBUFFERED: "1" + - name: Print installed packages + run: pip freeze - name: Start test server run: > python -m coverage run diff --git a/.gitignore b/.gitignore index 09734267ff5..6790e9ee728 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,4 @@ notification.mp3 /node_modules /package-lock.json /.coverage* +/test/test_outputs diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 517eadfd8de..ceda4bab919 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,140 +1,62 @@ -import os +from __future__ import annotations + +import logging -import cv2 import torch -import modules.face_restoration -import modules.shared -from modules import shared, devices, modelloader, errors -from modules.paths import models_path +from modules import ( + devices, + errors, + face_restoration, + face_restoration_utils, + modelloader, + shared, +) + +logger = logging.getLogger(__name__) -model_dir = "Codeformer" -model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' +model_download_name = 'codeformer-v0.1.0.pth' -codeformer = None +# used by e.g. postprocessing_codeformer.py +codeformer: face_restoration.FaceRestoration | None = None -class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): +class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): def name(self): return "CodeFormer" - def __init__(self, dirname): - self.net = None - self.face_helper = None - self.cmd_dir = dirname - - def create_models(self): - from facexlib.detection import retinaface - from facexlib.utils.face_restoration_helper import FaceRestoreHelper - - if self.net is not None and self.face_helper is not None: - self.net.to(devices.device_codeformer) - return self.net, self.face_helper - model_paths = modelloader.load_models( - model_path, - model_url, - self.cmd_dir, - download_name='codeformer-v0.1.0.pth', + def load_net(self) -> torch.Module: + for model_path in modelloader.load_models( + model_path=self.model_path, + model_url=model_url, + command_path=self.model_path, + download_name=model_download_name, ext_filter=['.pth'], - ) - - if len(model_paths) != 0: - ckpt_path = model_paths[0] - else: - print("Unable to load codeformer model.") - return None, None - net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer) - - if hasattr(retinaface, 'device'): - retinaface.device = devices.device_codeformer - - face_helper = FaceRestoreHelper( - upscale_factor=1, - face_size=512, - crop_ratio=(1, 1), - det_model='retinaface_resnet50', - save_ext='png', - use_parse=True, - device=devices.device_codeformer, - ) - - self.net = net - self.face_helper = face_helper - - def send_model_to(self, device): - self.net.to(device) - self.face_helper.face_det.to(device) - self.face_helper.face_parse.to(device) - - def restore(self, np_image, w=None): - from torchvision.transforms.functional import normalize - from basicsr.utils import img2tensor, tensor2img - np_image = np_image[:, :, ::-1] - - original_resolution = np_image.shape[0:2] - - self.create_models() - if self.net is None or self.face_helper is None: - return np_image - - self.send_model_to(devices.device_codeformer) - - self.face_helper.clean_all() - self.face_helper.read_image(np_image) - self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) - self.face_helper.align_warp_face() - - for cropped_face in 
self.face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) - - try: - with torch.no_grad(): - res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True) - if isinstance(res, tuple): - output = res[0] - else: - output = res - if not isinstance(res, torch.Tensor): - raise TypeError(f"Expected torch.Tensor, got {type(res)}") - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - devices.torch_gc() - except Exception: - errors.report('Failed inference for CodeFormer', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) - - self.face_helper.get_inverse_affine(None) - - restored_img = self.face_helper.paste_faces_to_input_image() - restored_img = restored_img[:, :, ::-1] + ): + return modelloader.load_spandrel_model( + model_path, + device=devices.device_codeformer, + ).model + raise ValueError("No codeformer model found") - if original_resolution != restored_img.shape[0:2]: - restored_img = cv2.resize( - restored_img, - (0, 0), - fx=original_resolution[1]/restored_img.shape[1], - fy=original_resolution[0]/restored_img.shape[0], - interpolation=cv2.INTER_LINEAR, - ) + def get_device(self): + return devices.device_codeformer - self.face_helper.clean_all() + def restore(self, np_image, w: float | None = None): + if w is None: + w = getattr(shared.opts, "code_former_weight", 0.5) - if shared.opts.face_restoration_unload: - self.send_model_to(devices.cpu) + def restore_face(cropped_face_t): + assert self.net is not None + return self.net(cropped_face_t, w=w, adain=True)[0] - return restored_img + return self.restore_with_helper(np_image, restore_face) -def setup_model(dirname): - os.makedirs(model_path, exist_ok=True) +def setup_model(dirname: str) -> None: + global codeformer try: - global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) except Exception: diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py new file mode 100644 index 00000000000..c65c85ef89d --- /dev/null +++ b/modules/face_restoration_utils.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import logging +import os +from functools import cached_property +from typing import TYPE_CHECKING, Callable + +import cv2 +import numpy as np +import torch + +from modules import devices, errors, face_restoration, shared + +if TYPE_CHECKING: + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + +logger = logging.getLogger(__name__) + + +def create_face_helper(device) -> FaceRestoreHelper: + from facexlib.detection import retinaface + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + if hasattr(retinaface, 'device'): + retinaface.device = device + return FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=device, + ) + + +def restore_with_face_helper( + np_image: np.ndarray, + face_helper: FaceRestoreHelper, + restore_face: Callable[[np.ndarray], np.ndarray], +) -> np.ndarray: + """ + Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. 
+ + `restore_face` should take a cropped face image and return a restored face image. + """ + from basicsr.utils import img2tensor, tensor2img + from torchvision.transforms.functional import normalize + np_image = np_image[:, :, ::-1] + original_resolution = np_image.shape[0:2] + + try: + logger.debug("Detecting faces...") + face_helper.clean_all() + face_helper.read_image(np_image) + face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + face_helper.align_warp_face() + logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) + for cropped_face in face_helper.cropped_faces: + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) + + try: + with torch.no_grad(): + restored_face = tensor2img( + restore_face(cropped_face_t), + rgb2bgr=True, + min_max=(-1, 1), + ) + devices.torch_gc() + except Exception: + errors.report('Failed face-restoration inference', exc_info=True) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + restored_face = restored_face.astype('uint8') + face_helper.add_restored_face(restored_face) + + logger.debug("Merging restored faces into image") + face_helper.get_inverse_affine(None) + img = face_helper.paste_faces_to_input_image() + img = img[:, :, ::-1] + if original_resolution != img.shape[0:2]: + img = cv2.resize( + img, + (0, 0), + fx=original_resolution[1] / img.shape[1], + fy=original_resolution[0] / img.shape[0], + interpolation=cv2.INTER_LINEAR, + ) + logger.debug("Face restoration complete") + finally: + face_helper.clean_all() + return img + + +class CommonFaceRestoration(face_restoration.FaceRestoration): + net: torch.Module | None + model_url: str + model_download_name: str + + def __init__(self, model_path: str): + super().__init__() + self.net = None + self.model_path = model_path + os.makedirs(model_path, exist_ok=True) + + @cached_property + def face_helper(self) -> FaceRestoreHelper: + return create_face_helper(self.get_device()) + + def send_model_to(self, device): + if self.net: + logger.debug("Sending %s to %s", self.net, device) + self.net.to(device) + if self.face_helper: + logger.debug("Sending face helper to %s", device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + + def get_device(self): + raise NotImplementedError("get_device must be implemented by subclasses") + + def load_net(self) -> torch.Module: + raise NotImplementedError("load_net must be implemented by subclasses") + + def restore_with_helper( + self, + np_image: np.ndarray, + restore_face: Callable[[np.ndarray], np.ndarray], + ) -> np.ndarray: + try: + if self.net is None: + self.net = self.load_net() + except Exception: + logger.warning("Unable to load face-restoration model", exc_info=True) + return np_image + + try: + self.send_model_to(self.get_device()) + return restore_with_face_helper(np_image, self.face_helper, restore_face) + finally: + if shared.opts.face_restoration_unload: + self.send_model_to(devices.cpu) + + +def patch_facexlib(dirname: str) -> None: + import facexlib.detection + import facexlib.parsing + + det_facex_load_file_from_url = facexlib.detection.load_file_from_url + par_facex_load_file_from_url = facexlib.parsing.load_file_from_url + + def update_kwargs(kwargs): + return dict(kwargs, save_dir=dirname, model_dir=None) + + def facex_load_file_from_url(**kwargs): + return 
det_facex_load_file_from_url(**update_kwargs(kwargs)) + + def facex_load_file_from_url2(**kwargs): + return par_facex_load_file_from_url(**update_kwargs(kwargs)) + + facexlib.detection.load_file_from_url = facex_load_file_from_url + facexlib.parsing.load_file_from_url = facex_load_file_from_url2 diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 6b6f17c435a..a356b56fe1a 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,126 +1,68 @@ +from __future__ import annotations + +import logging import os -import modules.face_restoration -from modules import paths, shared, devices, modelloader, errors +from modules import ( + devices, + errors, + face_restoration, + face_restoration_utils, + modelloader, + shared, +) -model_dir = "GFPGAN" -user_path = None -model_path = os.path.join(paths.models_path, model_dir) -model_file_path = None +logger = logging.getLogger(__name__) model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" -have_gfpgan = False -loaded_gfpgan_model = None - - -def gfpgann(): - global loaded_gfpgan_model - global model_path - global model_file_path - if loaded_gfpgan_model is not None: - loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) - return loaded_gfpgan_model - - if gfpgan_constructor is None: - return None - - models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth']) - - if len(models) == 1 and models[0].startswith("http"): - model_file = models[0] - elif len(models) != 0: - gfp_models = [] - for item in models: - if 'GFPGAN' in os.path.basename(item): - gfp_models.append(item) - latest_file = max(gfp_models, key=os.path.getctime) - model_file = latest_file - else: - print("Unable to load gfpgan model!") - return None - - import facexlib.detection.retinaface - - if hasattr(facexlib.detection.retinaface, 'device'): - facexlib.detection.retinaface.device = devices.device_gfpgan - model_file_path = model_file - model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) - loaded_gfpgan_model = model - - return model - - -def send_model_to(model, device): - model.gfpgan.to(device) - model.face_helper.face_det.to(device) - model.face_helper.face_parse.to(device) +model_download_name = "GFPGANv1.4.pth" +gfpgan_face_restorer: face_restoration.FaceRestoration | None = None + + +class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): + def name(self): + return "GFPGAN" + + def get_device(self): + return devices.device_gfpgan + + def load_net(self) -> None: + for model_path in modelloader.load_models( + model_path=self.model_path, + model_url=model_url, + command_path=self.model_path, + download_name=model_download_name, + ext_filter=['.pth'], + ): + if 'GFPGAN' in os.path.basename(model_path): + net = modelloader.load_spandrel_model( + model_path, + device=self.get_device(), + ).model + net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 + return net + raise ValueError("No GFPGAN model found") + + def restore(self, np_image): + def restore_face(cropped_face_t): + assert self.net is not None + return self.net(cropped_face_t, return_rgb=False)[0] + + return self.restore_with_helper(np_image, restore_face) def gfpgan_fix_faces(np_image): - model = gfpgann() - if model is None: - return np_image - - send_model_to(model, devices.device_gfpgan) - - np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = 
model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) - np_image = gfpgan_output_bgr[:, :, ::-1] - - model.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - send_model_to(model, devices.cpu) - + if gfpgan_face_restorer: + return gfpgan_face_restorer.restore(np_image) + logger.warning("GFPGAN face restorer not set up") return np_image -gfpgan_constructor = None +def setup_model(dirname: str) -> None: + global gfpgan_face_restorer - -def setup_model(dirname): try: - os.makedirs(model_path, exist_ok=True) - import gfpgan - import facexlib.detection - import facexlib.parsing - - global user_path - global have_gfpgan - global gfpgan_constructor - global model_file_path - - facexlib_path = model_path - - if dirname is not None: - facexlib_path = dirname - - load_file_from_url_orig = gfpgan.utils.load_file_from_url - facex_load_file_from_url_orig = facexlib.detection.load_file_from_url - facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url - - def my_load_file_from_url(**kwargs): - return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path)) - - def facex_load_file_from_url(**kwargs): - return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) - - def facex_load_file_from_url2(**kwargs): - return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) - - gfpgan.utils.load_file_from_url = my_load_file_from_url - facexlib.detection.load_file_from_url = facex_load_file_from_url - facexlib.parsing.load_file_from_url = facex_load_file_from_url2 - user_path = dirname - have_gfpgan = True - gfpgan_constructor = gfpgan.GFPGANer - - class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): - def name(self): - return "GFPGAN" - - def restore(self, np_image): - return gfpgan_fix_faces(np_image) - - shared.face_restorers.append(FaceRestorerGFPGAN()) + face_restoration_utils.patch_facexlib(dirname) + gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname) + shared.face_restorers.append(gfpgan_face_restorer) except Exception: errors.report("Error setting up GFPGAN", exc_info=True) diff --git a/requirements.txt b/requirements.txt index 36f5674ad99..b1329c9e3ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,6 @@ clean-fid einops facexlib fastapi>=0.90.1 -gfpgan gradio==3.41.2 inflection jsonmerge diff --git a/requirements_versions.txt b/requirements_versions.txt index 042fa708c49..edbb6db9e0e 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -7,7 +7,6 @@ clean-fid==0.1.35 einops==0.4.1 facexlib==0.3.0 fastapi==0.94.0 -gfpgan==1.3.8 gradio==3.41.2 httpcore==0.15 inflection==0.5.1 diff --git a/test/conftest.py b/test/conftest.py index 31a5d9eafb8..e4fc567854d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,10 +1,16 @@ +import base64 import os import pytest -import base64 - test_files_path = os.path.dirname(__file__) + "/test_files" +test_outputs_path = os.path.dirname(__file__) + "/test_outputs" + + +def pytest_configure(config): + # We don't want to fail on Py.test command line arguments being + # parsed by webui: + os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") def file_to_base64(filename): @@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str: @pytest.fixture(scope="session") # session so we don't read this over and over def mask_basic_image_base64() -> str: return file_to_base64(os.path.join(test_files_path, "mask_basic.png")) + + +@pytest.fixture(scope="session") +def initialize() -> 
None:
+    import webui  # noqa: F401
diff --git a/test/test_face_restorers.py b/test/test_face_restorers.py
new file mode 100644
index 00000000000..7760d51bfbe
--- /dev/null
+++ b/test/test_face_restorers.py
@@ -0,0 +1,29 @@
+import os
+from test.conftest import test_files_path, test_outputs_path
+
+import numpy as np
+import pytest
+from PIL import Image
+
+
+@pytest.mark.usefixtures("initialize")
+@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"])
+def test_face_restorers(restorer_name):
+    from modules import shared
+
+    if restorer_name == "gfpgan":
+        from modules import gfpgan_model
+        gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path)
+        restorer = gfpgan_model.gfpgan_fix_faces
+    elif restorer_name == "codeformer":
+        from modules import codeformer_model
+        codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)
+        restorer = codeformer_model.codeformer.restore
+    else:
+        raise NotImplementedError("...")
+    img = Image.open(os.path.join(test_files_path, "two-faces.jpg"))
+    np_img = np.array(img, dtype=np.uint8)
+    fixed_image = restorer(np_img)
+    assert fixed_image.shape == np_img.shape
+    assert not np.allclose(fixed_image, np_img)  # should have visibly changed
+    Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png"))
diff --git a/test/test_files/two-faces.jpg b/test/test_files/two-faces.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c9d1b01032a7298d76608c8b65cbb243463491c5
GIT binary patch
literal 14768
[14768 bytes of base85-encoded JPEG data omitted: the two-faces.jpg test fixture]

diff --git a/test/test_outputs/.gitkeep b/test/test_outputs/.gitkeep
new file mode 100644
index 00000000000..e69de29bb2d
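Note on the patch above: the CommonFaceRestoration base class reduces a face restorer to three pieces, a device, a network loader, and a per-face callable handed to restore_with_helper(). A minimal sketch of how a hypothetical subclass would plug in; the class name, checkpoint path, and the bare self.net(...) call signature are illustrative assumptions, not part of the patch:

    from modules import devices, face_restoration_utils, modelloader


    class FaceRestorerExample(face_restoration_utils.CommonFaceRestoration):
        def name(self):
            return "Example"

        def get_device(self):
            # any torch device works; the built-in restorers use
            # devices.device_gfpgan / devices.device_codeformer
            return devices.device_gfpgan

        def load_net(self):
            # load_spandrel_model returns a descriptor; .model is the torch module
            return modelloader.load_spandrel_model(
                "models/Example/example.pth",  # hypothetical checkpoint path
                device=self.get_device(),
            ).model

        def restore(self, np_image):
            def restore_face(cropped_face_t):
                # the helper supplies a normalized (1, 3, 512, 512) RGB face tensor
                return self.net(cropped_face_t)[0]

            return self.restore_with_helper(np_image, restore_face)

As with the GFPGAN and CodeFormer restorers above, constructing the object with a model directory and appending it to shared.face_restorers is the only registration needed.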
zn+@SOR_W&(n?IlX3leE;9UsdP;RFN45Aln>9xHS&Mz%7P!msVz%3~D^E_6S?0EjQ; zNSb+r61{&#I0Hsv}uSxXIbts;uiZ4i4Zc(m< z>QXK=N;bXu63saWc);QFCF-X;1>H?}^V1&YI+z!g&ic|)_q}ft8dnC|5R8LhHiy$6 z0!=BHmD;_IgRmQldqNa90n3QxGHP$d>mBe0e4PD+2IEHJ3?pJ!#xJ`$zp4v6Wj{`@yd}bo-njyD(j*& zBVBc_60V*OuKo+LC0xOpsG8601S3Y$>aB$Qi8RMJj-RNHuFIzyI9Jl9N*5c-IJ8#x zm|MYeum3lc^!Kpceu6pNVLY#AB`?>daKZx0l3V>Z(zE{;>Hjms2nzRqe~I~Dq-TK; zK_gXhhD2vIMh+q)XA@O*`6lZI<|+OQ_20waLI4era9R`oTIs!TItrNwjynr;eHifMZP&ZU{(%*Z%LoiQ{Mp1 zbZJpr28Fn?lpLNQRt!#lO((i6(?!35tsjF*ECtjk1MPGOgCBKss>qFYFz7DkR{t01 zmPrH^JCY5x`>g>EzpDl&3@Y5uQQcIv%|A&zRO7KYjq4I}c4w2fGwyKY?x>^e z0liwA~t@uA2bM6Iv{2}kJ)d$BV(E#u1;1b0yq+ zuYT!KW+c9w?jFxUi+iM{*qw?Q56SUAyA&%QGbADK4Thauk+8uT^QP^uVfoM=>n3{D zxM8RUKlNtGUdwV5l!cGAi+-C1G1nocw_KmFdUV# z3@YUSyQrVV;K1hn3jm?cLp&#zvtKB-U$(v@l&U>|%Zmg1646L(e=2w!J$)S(gZ-Hc zZU(%wG6XLoIe>;-f2xjL&O|#tYTvCU21mINssE`wWNlH)*lt0k!F~vGrjN>f^(G60 zGdVn66S?8KxZ4 zB%|;0nz@o{ZX$#H$Hc(RXpL&=;e1+j*%-POPtqx-@u}!jX{vIM&G2gz_9`$!VmBo= z6H;hIzt6(kl3Yb4<~YyJKKFyB+=UqrY|dUJn+M)iAD&`NDQd5ujjk!}sj_wO{8Wht z+y#M3r6r5GU7v$c-eHUw?OeiLu^yw~M{Oc_d8kO3th%LT;sm|#aNgmHHa2S#j-m+0 z0`FNy2^wlUvE4!V*b4{4$l<|LNyXI|2SO*j{w9OIjes0_ROg)MSg8WH*QN@|1Qk5= z$)8Aa8_?#q{nmW_Gef`AR$n@9PEY1MRlQGnS}Bd{zpL#(9~xIA2b;9c$9e^PP4aA88&T>*|=^ z^tYZ=JJJVT^8BSQTs5{W{z%TVW~{wk{aR_(AF(vD-NRfJxFG*fsg8H&{sOqut3o>U zR#!>w>h>i4I>SFxQ7$`s!i|-Lj~nSwwhr8UTUBhoWg-(5d_>(7ZAS-O_E#PmIkk{v zd4eNEkY7OfYE*dl40zHe2Ky<}Tt1?*TuX-8+3r>LFjk@_2N9cnzk)@&*n)c;hkRA@ zoQTf@wDBfoFW^!aCBpZ9tf7S0D6zXyop8+CYK$;jQ$gtifo?J&n8t%=;;EFr&WV$K zJaf6+JEf5qr2he9=>*G9_DZFTy@ci;<6u+_kyz4BaLlJ Date: Wed, 27 Dec 2023 10:55:01 +0200 Subject: [PATCH 0484/1174] Add experimental HAT model --- modules/hat_model.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 modules/hat_model.py diff --git a/modules/hat_model.py b/modules/hat_model.py new file mode 100644 index 00000000000..553e1941170 --- /dev/null +++ b/modules/hat_model.py @@ -0,0 +1,42 @@ +import os +import sys + +from modules import modelloader, devices +from modules.shared import opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model + + +class UpscalerHAT(Upscaler): + def __init__(self, dirname): + self.name = "HAT" + self.scalers = [] + self.user_path = dirname + super().__init__() + for file in self.find_models(ext_filter=[".pt", ".pth"]): + name = modelloader.friendly_name(file) + scale = 4 # TODO: scale might not be 4, but we can't know without loading the model + scaler_data = UpscalerData(name, file, upscaler=self, scale=scale) + self.scalers.append(scaler_data) + + def do_upscale(self, img, selected_model): + try: + model = self.load_model(selected_model) + except Exception as e: + print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr) + return img + model.to(devices.device_esrgan) # TODO: should probably be device_hat + return upscale_with_model( + model, + img, + tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile + tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap + ) + + def load_model(self, path: str): + if not os.path.isfile(path): + raise FileNotFoundError(f"Model file {path} not found") + return modelloader.load_spandrel_model( + path, + device=devices.device_esrgan, # TODO: should probably be device_hat + ) From 4ad0c0c0a805da4bac03cff86ea17c25a1291546 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 16:37:03 +0200 Subject: [PATCH 0485/1174] Verify architecture for loaded Spandrel models --- 
extensions-builtin/ScuNET/scripts/scunet_model.py | 2 +- extensions-builtin/SwinIR/scripts/swinir_model.py | 1 + modules/codeformer_model.py | 1 + modules/esrgan_model.py | 1 + modules/gfpgan_model.py | 1 + modules/hat_model.py | 1 + modules/modelloader.py | 13 ++++++++++++- modules/realesrgan_model.py | 7 ++++--- 8 files changed, 22 insertions(+), 5 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 18cf8e1a0b6..5f3dd08b337 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -121,7 +121,7 @@ def load_model(self, path: str): filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - return modelloader.load_spandrel_model(filename, device=device) + return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet') def on_ui_settings(): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 85c18b9e939..aae159af56e 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -75,6 +75,7 @@ def load_model(self, path, scale=4): filename, device=self._get_device(), dtype=devices.dtype, + expected_architecture="SwinIR", ) if getattr(opts, 'SWIN_torch_compile', False): try: diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ceda4bab919..44b84618ec0 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -37,6 +37,7 @@ def load_net(self) -> torch.Module: return modelloader.load_spandrel_model( model_path, device=devices.device_codeformer, + expected_architecture='CodeFormer', ).model raise ValueError("No codeformer model found") diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a7c7c9e3013..70041ab0234 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -49,6 +49,7 @@ def load_model(self, path: str): return modelloader.load_spandrel_model( filename, device=('cpu' if devices.device_esrgan.type == 'mps' else None), + expected_architecture='ESRGAN', ) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index a356b56fe1a..48f8ad5e294 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -37,6 +37,7 @@ def load_net(self) -> None: net = modelloader.load_spandrel_model( model_path, device=self.get_device(), + expected_architecture='GFPGAN', ).model net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 return net diff --git a/modules/hat_model.py b/modules/hat_model.py index 553e1941170..7f2abb41660 100644 --- a/modules/hat_model.py +++ b/modules/hat_model.py @@ -39,4 +39,5 @@ def load_model(self, path: str): return modelloader.load_spandrel_model( path, device=devices.device_esrgan, # TODO: should probably be device_hat + expected_architecture='HAT', ) diff --git a/modules/modelloader.py b/modules/modelloader.py index 30116932a1d..f4182559e59 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,6 +6,8 @@ import importlib from urllib.parse import urlparse +import torch + from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -183,9 +185,18 @@ def load_upscalers(): ) -def load_spandrel_model(path, *, device, half: bool = False, dtype=None): +def 
load_spandrel_model( + path: str, + *, + device: str | torch.device | None, + half: bool = False, + dtype: str | None = None, + expected_architecture: str | None = None, +): import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model.architecture != expected_architecture: + raise TypeError(f"Model {path} is not a {expected_architecture} model") if half: model = model.model.half() if dtype: diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 332d8f4b108..2a2be5ad7fe 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,9 +1,9 @@ import os -from modules.upscaler_utils import upscale_with_model -from modules.upscaler import Upscaler, UpscalerData -from modules.shared import cmd_opts, opts from modules import modelloader, errors +from modules.shared import cmd_opts, opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model class UpscalerRealESRGAN(Upscaler): @@ -40,6 +40,7 @@ def do_upscale(self, img, path): info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + expected_architecture="RealESRGAN", ) return upscale_with_model( mod, From 05230c02606080527b65ace9eacb6fb835239877 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 18:02:51 +0300 Subject: [PATCH 0486/1174] fix img2img api that i broke when implementing infotext support --- modules/api/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/api/api.py b/modules/api/api.py index 2918f785750..2e18c6b9185 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -507,6 +507,7 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) + args.pop('infotext', None) script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) From f476649c02cf3547d891fa08c50a92f92c4d73bd Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 17:41:19 +0200 Subject: [PATCH 0487/1174] Correct arg type for restore_face --- modules/face_restoration_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py index c65c85ef89d..85cb3057050 100644 --- a/modules/face_restoration_utils.py +++ b/modules/face_restoration_utils.py @@ -36,7 +36,7 @@ def create_face_helper(device) -> FaceRestoreHelper: def restore_with_face_helper( np_image: np.ndarray, face_helper: FaceRestoreHelper, - restore_face: Callable[[np.ndarray], np.ndarray], + restore_face: Callable[[torch.Tensor], torch.Tensor], ) -> np.ndarray: """ Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. 
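The second hunk below applies the same correction to CommonFaceRestoration.restore_with_helper: the callable receives the batched, normalize()-ed RGB tensor that restore_with_face_helper builds for each detected face, not a NumPy array. For illustration only, a no-op callable conforming to the corrected annotation (the function name is hypothetical):

    import torch

    def passthrough_restore_face(cropped_face_t: torch.Tensor) -> torch.Tensor:
        # shape (1, 3, 512, 512); values normalized to [-1, 1] by the helper
        return cropped_face_t

Passing this through restore_with_helper would paste the detected faces back unchanged.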
@@ -126,7 +126,7 @@ def load_net(self) -> torch.Module: def restore_with_helper( self, np_image: np.ndarray, - restore_face: Callable[[np.ndarray], np.ndarray], + restore_face: Callable[[torch.Tensor], torch.Tensor], ) -> np.ndarray: try: if self.net is None: From 91560e98c47f8271d444556ef4ae6505dece9aba Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Sat, 30 Dec 2023 23:42:10 +0800 Subject: [PATCH 0488/1174] fix format issue --- modules/api/api.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 2f718ec262f..d202cb8d27d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -418,7 +418,6 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img - with self.txt2img_script_arg_init_lock: if not script_runner.scripts: script_runner.initialize_scripts(False) @@ -489,14 +488,13 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img - with self.img2img_script_arg_init_lock: if not script_runner.scripts: script_runner.initialize_scripts(True) ui.create_ui() - infotext_script_args = {} - self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) if not self.default_script_arg_img2img: self.default_script_arg_img2img = self.init_default_script_args(script_runner) From c9174253fb603e6b2552e4c2721fd767b6ede87d Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 17:45:26 +0200 Subject: [PATCH 0489/1174] Drop dependency on basicsr --- modules/face_restoration_utils.py | 35 +++++++++++++++++++++++-------- requirements.txt | 1 - requirements_versions.txt | 1 - 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py index 85cb3057050..1cbac236480 100644 --- a/modules/face_restoration_utils.py +++ b/modules/face_restoration_utils.py @@ -17,6 +17,28 @@ logger = logging.getLogger(__name__) +def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor: + """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor.""" + assert img.shape[2] == 3, "image must be RGB" + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return torch.from_numpy(img.transpose(2, 0, 1)).float() + + +def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray: + """ + Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range. + """ + tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) + assert tensor.dim() == 3, "tensor must be RGB" + img_np = tensor.numpy().transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image, no RGB/BGR required + return np.squeeze(img_np, axis=2) + return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) + + def create_face_helper(device) -> FaceRestoreHelper: from facexlib.detection import retinaface from facexlib.utils.face_restoration_helper import FaceRestoreHelper @@ -43,7 +65,6 @@ def restore_with_face_helper( `restore_face` should take a cropped face image and return a restored face image. 
""" - from basicsr.utils import img2tensor, tensor2img from torchvision.transforms.functional import normalize np_image = np_image[:, :, ::-1] original_resolution = np_image.shape[0:2] @@ -56,23 +77,19 @@ def restore_with_face_helper( face_helper.align_warp_face() logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) for cropped_face in face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) try: with torch.no_grad(): - restored_face = tensor2img( - restore_face(cropped_face_t), - rgb2bgr=True, - min_max=(-1, 1), - ) + cropped_face_t = restore_face(cropped_face_t) devices.torch_gc() except Exception: errors.report('Failed face-restoration inference', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - restored_face = restored_face.astype('uint8') + restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1)) + restored_face = (restored_face * 255.0).astype('uint8') face_helper.add_restored_face(restored_face) logger.debug("Merging restored faces into image") diff --git a/requirements.txt b/requirements.txt index b1329c9e3ab..731a1be7d9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ GitPython Pillow accelerate -basicsr blendmodes clean-fid einops diff --git a/requirements_versions.txt b/requirements_versions.txt index edbb6db9e0e..1e0ccafa7d2 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,7 +1,6 @@ GitPython==3.1.32 Pillow==9.5.0 accelerate==0.21.0 -basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1 From b58ed1b2432c3c7643b39e53f7bb567ea8655aae Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 18:02:01 +0200 Subject: [PATCH 0490/1174] Bump numpy to 1.26.2 This avoids it being downgraded during `launch.py` --- requirements_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index edbb6db9e0e..7ec7abe2ee2 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -13,7 +13,7 @@ inflection==0.5.1 jsonmerge==1.8.0 kornia==0.6.7 lark==1.1.2 -numpy==1.23.5 +numpy==1.26.2 omegaconf==2.2.3 open-clip-torch==2.20.0 piexif==1.1.3 From f651405427dfc6d4ef96ecba7f9c2ceb580263fd Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Sun, 31 Dec 2023 01:09:13 +0800 Subject: [PATCH 0491/1174] remove locks, move init code to __init__ --- modules/api/api.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index d202cb8d27d..fc3921c23e1 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -251,8 +251,21 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] - self.txt2img_script_arg_init_lock = Lock() - self.img2img_script_arg_init_lock = Lock() + txt2img_script_runner = scripts.scripts_txt2img + img2img_script_runner = scripts.scripts_img2img + + if not txt2img_script_runner.scripts or not img2img_script_runner.scripts: + ui.create_ui() + + if not txt2img_script_runner.scripts: + txt2img_script_runner.initialize_scripts(False) + if not self.default_script_arg_txt2img: + self.default_script_arg_txt2img = 
self.init_default_script_args(txt2img_script_runner) + + if not img2img_script_runner.scripts: + img2img_script_runner.initialize_scripts(True) + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner) @@ -418,16 +431,10 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img - with self.txt2img_script_arg_init_lock: - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - infotext_script_args = {} - self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + infotext_script_args = {} + self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) - if not self.default_script_arg_txt2img: - self.default_script_arg_txt2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params @@ -488,16 +495,10 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img - with self.img2img_script_arg_init_lock: - if not script_runner.scripts: - script_runner.initialize_scripts(True) - ui.create_ui() - infotext_script_args = {} - self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) - if not self.default_script_arg_img2img: - self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params From 1465dab71564bb30091479ceabae6c69e3426bc6 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 19:44:05 +0200 Subject: [PATCH 0492/1174] Make Tensorboard a late import (it was implicitly installed by basicsr) --- modules/textual_inversion/textual_inversion.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 04dda585cc9..c6bcab1538d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -11,7 +11,6 @@ import numpy as np from PIL import Image, PngImagePlugin -from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset @@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) def tensorboard_setup(log_directory): + from torch.utils.tensorboard import SummaryWriter os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) return SummaryWriter( log_dir=os.path.join(log_directory, "tensorboard"), @@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
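    # A sketch of the guard applied in the hunk below: the writer defaults to
    # None and is only assigned when the option is enabled and the optional
    # tensorboard dependency imports cleanly, so a missing package degrades to
    # a reported error instead of crashing training.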
     old_parallel_processing_allowed = shared.parallel_processing_allowed

+    tensorboard_writer = None
     if shared.opts.training_enable_tensorboard:
-        tensorboard_writer = tensorboard_setup(log_directory)
+        try:
+            tensorboard_writer = tensorboard_setup(log_directory)
+        except ImportError:
+            errors.report("Error initializing tensorboard", exc_info=True)

     pin_memory = shared.opts.pin_memory
@@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                     last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                     last_saved_image += f", prompt: {preview_text}"

-                    if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+                    if tensorboard_writer and shared.opts.training_tensorboard_save_images:
                         tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

                 if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:

From 48a2a1a437a48cc232725cc813242f98483b7697 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Sat, 30 Dec 2023 19:44:38 +0200
Subject: [PATCH 0493/1174] Don't wait for 10 minutes for test server to come
 up

---
 .github/workflows/run_tests.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index cd5c3f868ff..f42e4758e63 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -57,7 +57,7 @@ jobs:
           2>&1 | tee output.txt &
       - name: Run tests
         run: |
-          wait-for-it --service 127.0.0.1:7860 -t 600
+          wait-for-it --service 127.0.0.1:7860 -t 20
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
       - name: Kill test server
         if: always()

From 5fbb13e0da8eb2e26bd2c45ec8ffbb2de669ef47 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Sat, 30 Dec 2023 20:46:44 +0200
Subject: [PATCH 0494/1174] Remove `cleanup_models` code

---
 modules/initialize.py  |  3 ---
 modules/modelloader.py | 50 ------------------------------------------
 2 files changed, 53 deletions(-)

diff --git a/modules/initialize.py b/modules/initialize.py
index ac95fc6f00c..4a3cd98cf86 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -54,9 +54,6 @@ def initialize():
     initialize_util.configure_sigint_handler()
     initialize_util.configure_opts_onchange()

-    from modules import modelloader
-    modelloader.cleanup_models()
-
     from modules import sd_models
     sd_models.setup_model()
     startup_timer.record("setup SD model")
diff --git a/modules/modelloader.py b/modules/modelloader.py
index f4182559e59..5f7aec3e42a 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -2,7 +2,6 @@
 import logging
 import os
-import shutil
 import importlib
 from urllib.parse import urlparse

@@ -10,7 +9,6 @@
 from modules import shared
 from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
-from modules.paths import script_path, models_path

 logger = logging.getLogger(__name__)

@@ -96,54 +94,6 @@ def friendly_name(file: str):
     return model_name


-def cleanup_models():
-    # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
-    # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
-    # somehow auto-register and just do these things...
- root_path = script_path - src_path = models_path - dest_path = os.path.join(models_path, "Stable-diffusion") - move_files(src_path, dest_path, ".ckpt") - move_files(src_path, dest_path, ".safetensors") - src_path = os.path.join(root_path, "ESRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path) - src_path = os.path.join(models_path, "BSRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path, ".pth") - src_path = os.path.join(root_path, "gfpgan") - dest_path = os.path.join(models_path, "GFPGAN") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "SwinIR") - dest_path = os.path.join(models_path, "SwinIR") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/") - dest_path = os.path.join(models_path, "LDSR") - move_files(src_path, dest_path) - - -def move_files(src_path: str, dest_path: str, ext_filter: str = None): - try: - os.makedirs(dest_path, exist_ok=True) - if os.path.exists(src_path): - for file in os.listdir(src_path): - fullpath = os.path.join(src_path, file) - if os.path.isfile(fullpath): - if ext_filter is not None: - if ext_filter not in file: - continue - print(f"Moving {file} from {src_path} to {dest_path}.") - try: - shutil.move(fullpath, dest_path) - except Exception: - pass - if len(os.listdir(src_path)) == 0: - print(f"Removing empty folder: {src_path}") - shutil.rmtree(src_path, True) - except Exception: - pass - - def load_upscalers(): # We can only do this 'magic' method to dynamically load upscalers if they are referenced, # so we'll try to import any _model.py files before looking in __subclasses__ From af050dcaa75ef40b6b1c3da3361f32fe52786aeb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 21:05:59 +0200 Subject: [PATCH 0495/1174] Soften Spandrel model-architecture check to just a warning --- modules/modelloader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index f4182559e59..6b7d697f103 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -196,7 +196,9 @@ def load_spandrel_model( import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) if expected_architecture and model.architecture != expected_architecture: - raise TypeError(f"Model {path} is not a {expected_architecture} model") + logger.warning( + f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + ) if half: model = model.model.half() if dtype: From 393a5b82ba6df06d85f2bf7bbbe0456d3d06115f Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 21:12:32 +0200 Subject: [PATCH 0496/1174] Correct RealESRGAN expected architecture type to ESRGAN --- modules/realesrgan_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 2a2be5ad7fe..65f2e880668 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -40,7 +40,7 @@ def do_upscale(self, img, path): info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), - expected_architecture="RealESRGAN", + expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( mod, From 8100e901ab0c5b04d289eebb722c8a653b8beef1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 22:41:53 +0300 Subject: 
[PATCH 0497/1174] fix error with RealESRGAN model failing to upscale fp32 image --- modules/upscaler_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8bdda51c4c2..39f78a0b7bc 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -16,9 +16,13 @@ def upscale_without_tiling(model, img: Image.Image): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) + + model_weight = next(iter(model.parameters())) + img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + with torch.no_grad(): output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = 255. * np.moveaxis(output, 0, 2) output = output.astype(np.uint8) From bc5ae74c7d8949bab37e260b16e76889b9968099 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 21:52:27 +0100 Subject: [PATCH 0498/1174] Added negative prompts to extra networks lora --- .../Lora/ui_edit_user_metadata.py | 14 +++++++-- .../Lora/ui_extra_networks_lora.py | 9 ++++++ javascript/extraNetworks.js | 31 +++++++++++++------ modules/ui_extra_networks.py | 5 ++- 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index c7011909055..f7859b21f07 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,12 +54,14 @@ def __init__(self, ui, tabname, page): self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight + user_metadata["negative text"] = negative_text + user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -127,6 +129,8 @@ def put_values_into_components(self, name): gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), + user_metadata.get('negative text', ''), + float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -162,7 +166,8 @@ def create_editor(self): self.taginfo = gr.HighlightedText(label="Training dataset tags") self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) - + self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") + self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = 
gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -198,6 +203,8 @@ def select_tag(activation_text, evt: gr.SelectData): self.taginfo, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -211,7 +218,10 @@ def select_tag(activation_text, evt: gr.SelectData): self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, self.edit_notes, ] + self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index df02c663b12..09ce2a0578e 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -45,6 +45,15 @@ def create_item(self, name, index=None, enable_filter=True): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + negative_prompt = item["user_metadata"].get("negative text") + preferred_negative_weight = item["user_metadata"].get("negative weight") + item["negative_prompt"] = quote_js("") + if negative_prompt: + neg_prompt = negative_prompt + if (preferred_negative_weight > 0): + neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 98a7abb745c..2bb9795dca4 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -185,8 +185,10 @@ onUiLoaded(setupExtraNetworks); var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/; var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g; -function tryToRemoveExtraNetworkFromPrompt(textarea, text) { - var m = text.match(re_extranet); +var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/; +var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g; +function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { + var m = text.match(isNeg ? re_extranet_neg : re_extranet); var replaced = false; var newTextareaText; if (m) { @@ -194,8 +196,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; - newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) { - m = found.match(re_extranet); + newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) { + m = found.match(isNeg ? 
re_extranet_neg : re_extranet); if (m[1] == partToSearch) { replaced = true; foundAtPosition = pos; @@ -205,7 +207,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { }); if (foundAtPosition >= 0) { - if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); } if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) { @@ -230,14 +232,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { return false; } -function cardClicked(tabname, textToAdd, allowNegativePrompt) { - var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); +function updatePromptArea(text, textArea, isNeg) { - if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) { - textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd; + if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) { + textArea.value = textArea.value + opts.extra_networks_add_text_separator + text; } - updateInput(textarea); + updateInput(textArea); +} + +function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) { + if (textToAddNegative.length > 0) { + updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")) + updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true) + } else { + var textarea = allowNegativePrompt ? 
activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); + updatePromptArea(textToAdd, textarea) + } } function saveCardPreview(event, tabname, filename) { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index fe5d3ba3338..b8c02241365 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -223,7 +223,10 @@ def create_html_for_item(self, item, tabname): onclick = item.get("onclick", None) if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + if "negative_prompt" in item: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {item["negative_prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + else: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {'""'}, {"true" if self.allow_negative_prompt else "false"})""") + '"' height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' From a2f23f9d22dde87bf2529dcb2854a6a5d3d44278 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 22:16:51 +0100 Subject: [PATCH 0499/1174] Code Style fixes --- extensions-builtin/Lora/ui_extra_networks_lora.py | 4 ++-- javascript/extraNetworks.js | 6 +++--- modules/upscaler_utils.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 09ce2a0578e..9a6624e3f1e 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -52,8 +52,8 @@ def create_item(self, name, index=None, enable_filter=True): neg_prompt = negative_prompt if (preferred_negative_weight > 0): neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) - + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 2bb9795dca4..f1ad19a66b7 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -243,11 +243,11 @@ function updatePromptArea(text, textArea, isNeg) { function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) { if (textToAddNegative.length > 0) { - updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")) - updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true) + updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")); + updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true); } else { var textarea = allowNegativePrompt ? 
activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); - updatePromptArea(textToAdd, textarea) + updatePromptArea(textToAdd, textarea); } } diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b7bc..1d610dbffd9 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import tqdm from PIL import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) From 3be90740316f8fbb950b31d440458a5e8ed4beb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 00:43:41 +0300 Subject: [PATCH 0500/1174] fix for the previous fix. --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b7bc..dde5d7ad43a 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -17,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.parameters())) + model_weight = next(iter(model.model.parameters())) img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) with torch.no_grad(): From c0ca6348e8489651df861a101142805c213c66a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:04:47 +0200 Subject: [PATCH 0501/1174] load_spandrel_model: always return a model descriptor --- modules/modelloader.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index 0b89d682c55..8bcee08c155 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,8 +1,9 @@ from __future__ import annotations +import importlib import logging import os -import importlib +from typing import TYPE_CHECKING from urllib.parse import urlparse import torch @@ -10,6 +11,8 @@ from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone +if TYPE_CHECKING: + import spandrel logger = logging.getLogger(__name__) @@ -142,17 +145,17 @@ def load_spandrel_model( half: bool = False, dtype: str | None = None, expected_architecture: str | None = None, -): +) -> spandrel.ModelDescriptor: import spandrel - model = spandrel.ModelLoader(device=device).load_from_file(path) - if expected_architecture and model.architecture != expected_architecture: + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( - f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) if half: - model = model.model.half() + model_descriptor.model.half() if dtype: - model = model.model.to(dtype=dtype) - model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) - return model + model_descriptor.model.to(dtype=dtype) + model_descriptor.model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + return model_descriptor From 777af661a21821994993df3ef566b01df2bb61a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:09:51 +0200 Subject: [PATCH 0502/1174] Be more clear about 
Spandrel model nomenclature --- extensions-builtin/SwinIR/scripts/swinir_model.py | 6 +++--- modules/gfpgan_model.py | 10 ++++++---- modules/modelloader.py | 2 +- modules/realesrgan_model.py | 4 ++-- modules/upscaler_utils.py | 2 +- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index aae159af56e..95c7ec648ef 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -71,7 +71,7 @@ def load_model(self, path, scale=4): else: filename = path - model = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), dtype=devices.dtype, @@ -79,10 +79,10 @@ def load_model(self, path, scale=4): ) if getattr(opts, 'SWIN_torch_compile', False): try: - model = torch.compile(model) + model_descriptor.model.compile() except Exception: logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) - return model + return model_descriptor def _get_device(self): return devices.get_device_for('swinir') diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 48f8ad5e294..445b040925e 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -3,6 +3,8 @@ import logging import os +import torch + from modules import ( devices, errors, @@ -25,7 +27,7 @@ def name(self): def get_device(self): return devices.device_gfpgan - def load_net(self) -> None: + def load_net(self) -> torch.Module: for model_path in modelloader.load_models( model_path=self.model_path, model_url=model_url, @@ -34,13 +36,13 @@ def load_net(self) -> None: ext_filter=['.pth'], ): if 'GFPGAN' in os.path.basename(model_path): - net = modelloader.load_spandrel_model( + model = modelloader.load_spandrel_model( model_path, device=self.get_device(), expected_architecture='GFPGAN', ).model - net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 - return net + model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 + return model raise ValueError("No GFPGAN model found") def restore(self, np_image): diff --git a/modules/modelloader.py b/modules/modelloader.py index 8bcee08c155..a7194137571 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -143,7 +143,7 @@ def load_spandrel_model( *, device: str | torch.device | None, half: bool = False, - dtype: str | None = None, + dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 65f2e880668..4d35b695c3b 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -36,14 +36,14 @@ def do_upscale(self, img, path): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - mod = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( - mod, + model_descriptor, img, tile_size=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index dde5d7ad43a..174c9bc3713 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import tqdm from PIL 
import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) From 6f86b62a1be7993073ba3a789d522e0b8870605a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 22:53:49 +0200 Subject: [PATCH 0503/1174] Deduplicate tiled inference code from SwinIR/ScuNET --- .../ScuNET/scripts/scunet_model.py | 55 +++----------- .../SwinIR/scripts/swinir_model.py | 57 ++------------- modules/upscaler_utils.py | 72 ++++++++++++++++++- 3 files changed, 87 insertions(+), 97 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 5f3dd08b337..f799cb76dec 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -3,12 +3,11 @@ import PIL.Image import numpy as np import torch -from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors - from modules.shared import opts +from modules.upscaler_utils import tiled_upscale_2 class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,47 +39,6 @@ def __init__(self, dirname): scalers.append(scaler_data2) self.scalers = scalers - @staticmethod - @torch.no_grad() - def tiled_inference(img, model): - # test the image tile by tile - h, w = img.shape[2:] - tile = opts.SCUNET_tile - tile_overlap = opts.SCUNET_tile_overlap - if tile == 0: - return model(img) - - device = devices.get_device_for('scunet') - assert tile % 8 == 0, "tile size should be a multiple of window_size" - sf = 1 - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: - for h_idx in h_idx_list: - - for w_idx in w_idx_list: - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - def do_upscale(self, img: PIL.Image.Image, selected_file): devices.torch_gc() @@ -104,7 +62,16 @@ def do_upscale(self, img: PIL.Image.Image, selected_file): _img[:, :, :h, :w] = torch_img # pad image torch_img = _img - torch_output = self.tiled_inference(torch_img, model).squeeze(0) + with torch.no_grad(): + torch_output = tiled_upscale_2( + torch_img, + model, + tile_size=opts.SCUNET_tile, + tile_overlap=opts.SCUNET_tile_overlap, + scale=1, + device=devices.get_device_for('scunet'), + desc="ScuNET tiles", + ).squeeze(0) torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() del torch_img, torch_output diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 95c7ec648ef..8a555c794f7 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -4,11 +4,11 @@ import numpy as np import torch from PIL import Image -from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from 
modules.shared import opts, state +from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -110,14 +110,14 @@ def upscale( w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference( + output = tiled_upscale_2( img, model, - tile=tile, + tile_size=tile, tile_overlap=tile_overlap, - window_size=window_size, scale=scale, device=device, + desc="SwinIR tiles", ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() @@ -129,53 +129,6 @@ def upscale( return Image.fromarray(output, "RGB") -def inference( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size: int, - scale: int, - device, -): - # test the image tile by tile - b, c, h, w = img.size() - tile = min(tile, h, w) - assert tile % window_size == 0, "tile size should be a multiple of window_size" - sf = scale - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: - for h_idx in h_idx_list: - if state.interrupted or state.skipped: - break - - for w_idx in w_idx_list: - if state.interrupted or state.skipped: - break - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - - def on_ui_settings(): import gradio as gr diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 174c9bc3713..8e413854303 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import tqdm from PIL import Image -from modules import images +from modules import images, shared logger = logging.getLogger(__name__) @@ -68,3 +68,73 @@ def upscale_with_model( overlap=grid.overlap * scale_factor, ) return images.combine_grid(newgrid) + + +def tiled_upscale_2( + img, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + device, + desc="Tiled upscale", +): + # Alternative implementation of `upscale_with_model` originally used by + # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and + # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in + # Pillow space without weighting. 
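    # In sketch form, for every output pixel p the loop below accumulates
    #     result[p]  = sum of model outputs over all tiles covering p
    #     weights[p] = number of tiles covering p
    # so the final result.div_(weights) averages overlapping tile outputs,
    # softening the seams at tile borders.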
+ b, c, h, w = img.size() + tile_size = min(tile_size, h, w) + + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img.shape) + return model(img) + + stride = tile_size - tile_overlap + h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] + w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] + result = torch.zeros( + b, + c, + h * scale, + w * scale, + device=device, + ).type_as(img) + weights = torch.zeros_like(result) + logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + for h_idx in h_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + for w_idx in w_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + in_patch = img[ + ..., + h_idx : h_idx + tile_size, + w_idx : w_idx + tile_size, + ] + out_patch = model(in_patch) + + result[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch) + + out_patch_mask = torch.ones_like(out_patch) + + weights[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch_mask) + + pbar.update(1) + + output = result.div_(weights) + + return output From 5768afc776a66bb94e77a9c1daebeea58fa731d5 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:20:30 +0200 Subject: [PATCH 0504/1174] Add utility to inspect a model's parameters (to get dtype/device) --- modules/devices.py | 3 ++- modules/interrogate.py | 3 ++- modules/sd_models_xl.py | 3 ++- modules/torch_utils.py | 17 +++++++++++++++++ modules/upscaler_utils.py | 5 +++-- modules/xlmr.py | 5 ++++- modules/xlmr_m18.py | 5 ++++- test/test_torch_utils.py | 19 +++++++++++++++++++ 8 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 modules/torch_utils.py create mode 100644 test/test_torch_utils.py diff --git a/modules/devices.py b/modules/devices.py index c956207f334..bd6bd579b1a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,6 +4,7 @@ import torch from modules import errors, shared +from modules.torch_utils import get_param if sys.platform == "darwin": from modules import mac_specific @@ -131,7 +132,7 @@ def cond_cast_float(input): def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype + org_dtype = get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py b/modules/interrogate.py index 3045560d0ae..5be5a10f39c 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,6 +11,7 @@ from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors +from modules.torch_utils import get_param blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -131,7 +132,7 @@ def load(self): self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = next(self.clip_model.parameters()).dtype + self.dtype = get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1de31b0da19..c3602a7e1a6 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,6 +6,7 @@ import 
sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser +from modules.torch_utils import get_param def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -90,7 +91,7 @@ def get_target_prompt_token_count(self, token_count): def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = next(model.model.diffusion_model.parameters()).dtype + dtype = get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/torch_utils.py b/modules/torch_utils.py new file mode 100644 index 00000000000..e5b52393ec8 --- /dev/null +++ b/modules/torch_utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import torch.nn + + +def get_param(model) -> torch.nn.Parameter: + """ + Find the first parameter in a model or module. + """ + if hasattr(model, "model") and hasattr(model.model, "parameters"): + # Unpeel a model descriptor to get at the actual Torch module. + model = model.model + + for param in model.parameters(): + return param + + raise ValueError(f"No parameters found in model {model!r}") diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8e413854303..c60e3beb89d 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -7,6 +7,7 @@ from PIL import Image from modules import images, shared +from modules.torch_utils import get_param logger = logging.getLogger(__name__) @@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.model.parameters())) - img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + param = get_param(model) + img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): output = model(img) diff --git a/modules/xlmr.py b/modules/xlmr.py index a407a3cade8..6e000a56e75 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,6 +5,9 @@ from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -62,7 +65,7 @@ def __init__(self, config=None, **kargs): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index a727e865529..e3e8196108c 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -5,6 +5,9 @@ from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, 
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -68,7 +71,7 @@ def __init__(self, config=None, **kargs): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py new file mode 100644 index 00000000000..f1aec832945 --- /dev/null +++ b/test/test_torch_utils.py @@ -0,0 +1,19 @@ +import types + +import pytest +import torch + +from modules.torch_utils import get_param + + +@pytest.mark.parametrize("wrapped", [True, False]) +def test_get_param(wrapped): + mod = torch.nn.Linear(1, 1) + cpu = torch.device("cpu") + mod.to(dtype=torch.float16, device=cpu) + if wrapped: + # more or less how spandrel wraps a thing + mod = types.SimpleNamespace(model=mod) + p = get_param(mod) + assert p.dtype == torch.float16 + assert p.device == cpu From d4945f4422e5a0bf31a6dbe4c1aeedd78c09eacb Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:22:30 +0100 Subject: [PATCH 0505/1174] Removed weight slider for negative prompts --- extensions-builtin/Lora/ui_edit_user_metadata.py | 7 +------ extensions-builtin/Lora/ui_extra_networks_lora.py | 6 +----- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index f7859b21f07..3160aecfa38 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,14 +54,13 @@ def __init__(self, ui, tabname, page): self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight user_metadata["negative text"] = negative_text - user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -130,7 +129,6 @@ def put_values_into_components(self, name): user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), user_metadata.get('negative text', ''), - float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -167,7 +165,6 @@ def create_editor(self): self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") - self.slider_negative_weight = 
gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -204,7 +201,6 @@ def select_tag(activation_text, evt: gr.SelectData): self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -219,7 +215,6 @@ def select_tag(activation_text, evt: gr.SelectData): self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, self.edit_notes, ] diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 9a6624e3f1e..e714fac4692 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -46,13 +46,9 @@ def create_item(self, name, index=None, enable_filter=True): item["prompt"] += " + " + quote_js(" " + activation_text) negative_prompt = item["user_metadata"].get("negative text") - preferred_negative_weight = item["user_metadata"].get("negative weight") item["negative_prompt"] = quote_js("") if negative_prompt: - neg_prompt = negative_prompt - if (preferred_negative_weight > 0): - neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) + item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: From b6f74e936e4de3b8d190bffaf3bed67d6d4bd211 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:36:36 +0100 Subject: [PATCH 0506/1174] Revert change from linting for unrelated file --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 1d610dbffd9..39f78a0b7bc 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import tqdm from PIL import Image -from modules import images +from modules import devices, images logger = logging.getLogger(__name__) From a70dfb64a86b9b6d869deffdb0ffebe980365473 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 22:38:30 +0300 Subject: [PATCH 0507/1174] change import statements for #14478 --- modules/devices.py | 4 ++-- modules/interrogate.py | 5 ++--- modules/sd_models_xl.py | 4 ++-- modules/upscaler_utils.py | 5 ++--- modules/xlmr.py | 4 ++-- modules/xlmr_m18.py | 5 ++--- test/test_torch_utils.py | 4 ++-- 7 files changed, 14 insertions(+), 17 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index bd6bd579b1a..ff279ac5016 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,7 +4,7 @@ import torch from modules import errors, shared -from modules.torch_utils import get_param +from modules import torch_utils if sys.platform == "darwin": from modules import mac_specific @@ -132,7 +132,7 @@ def cond_cast_float(input): def manual_cast_forward(self, *args, **kwargs): - org_dtype = get_param(self).dtype + org_dtype = torch_utils.get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py 
b/modules/interrogate.py index 5be5a10f39c..35a627ca196 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -10,8 +10,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from modules import devices, paths, shared, lowvram, modelloader, errors -from modules.torch_utils import get_param +from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -132,7 +131,7 @@ def load(self): self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = get_param(self.clip_model).dtype + self.dtype = torch_utils.get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index c3602a7e1a6..0de17af3db2 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,7 +6,7 @@ import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser -from modules.torch_utils import get_param +from modules import torch_utils def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -91,7 +91,7 @@ def get_target_prompt_token_count(self, token_count): def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = get_param(model.model.diffusion_model).dtype + dtype = torch_utils.get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index c60e3beb89d..f5cb92d51b8 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,8 +6,7 @@ import tqdm from PIL import Image -from modules import images, shared -from modules.torch_utils import get_param +from modules import images, shared, torch_utils logger = logging.getLogger(__name__) @@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - param = get_param(model) + param = torch_utils.get_param(model) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): diff --git a/modules/xlmr.py b/modules/xlmr.py index 6e000a56e75..319771b7bf0 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,7 +5,7 @@ from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -65,7 +65,7 @@ def __init__(self, config=None, **kargs): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index e3e8196108c..f60555049f5 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -4,8 +4,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional - -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -71,7 +70,7 @@ def __init__(self, config=None, 
**kargs): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py index f1aec832945..23ccb93a464 100644 --- a/test/test_torch_utils.py +++ b/test/test_torch_utils.py @@ -3,7 +3,7 @@ import pytest import torch -from modules.torch_utils import get_param +from modules import torch_utils @pytest.mark.parametrize("wrapped", [True, False]) @@ -14,6 +14,6 @@ def test_get_param(wrapped): if wrapped: # more or less how spandrel wraps a thing mod = types.SimpleNamespace(model=mod) - p = get_param(mod) + p = torch_utils.get_param(mod) assert p.dtype == torch.float16 assert p.device == cpu From 00901bfbe0095303554f4440b4c12fac262e2e89 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 1 Jan 2024 15:47:57 +0900 Subject: [PATCH 0508/1174] handle selectable script_index is None --- modules/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/scripts.py b/modules/scripts.py index 3a76691186b..017aed5a01e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -696,6 +696,8 @@ def setup_ui(self): self.setup_ui_for_section(None, self.selectable_scripts) def select_script(script_index): + if script_index is None: + script_index = 0 selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None return [gr.update(visible=selected_script == s) for s in self.selectable_scripts] @@ -739,7 +741,7 @@ def onload_script_visibility(params): def run(self, p, *args): script_index = args[0] - if script_index == 0: + if script_index == 0 or script_index is None: return None script = self.selectable_scripts[script_index-1] From 5692bf1517c3409ad46262c56e65f256389825b1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 11:11:14 +0300 Subject: [PATCH 0509/1174] add missing field for DDIM sampler that was breaking img2img --- modules/sd_samplers_timesteps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index b17a8f93c2b..f8afa8bd7e5 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -80,6 +80,7 @@ def __init__(self, funcname, sd_model): self.eta_default = 0.0 self.model_wrap_cfg = CFGDenoiserTimesteps(self) + self.model_wrap = self.model_wrap_cfg.inner_model def get_timesteps(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) From 003b91f08361c99ecdd97257624d81a2046d3823 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:45:01 +0300 Subject: [PATCH 0510/1174] rename generation_parameters_copypaste module to infotext --- modules/{generation_parameters_copypaste.py => infotext.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{generation_parameters_copypaste.py => infotext.py} (100%) diff --git a/modules/generation_parameters_copypaste.py b/modules/infotext.py similarity index 100% rename from modules/generation_parameters_copypaste.py rename to modules/infotext.py From c5496c76461c90bd186ae8804aa65a33cd136d48 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:52:37 +0300 Subject: [PATCH 0511/1174] infotext.py: add support for old modules.generation_parameters_copypaste name --- modules/infotext.py | 3 +++ 1 file 
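
The guard added in patch 0508 covers Gradio handing back `None` for a dropdown's index before any selection is made, both in `select_script` and in `run`. The normalization can be sketched in isolation; the script list below is a stand-in, not the webui `ScriptRunner`:

```python
# Stand-in for ScriptRunner.selectable_scripts; dropdown entry 0 means "None".
selectable_scripts = ["img2img alternative", "loopback", "xyz grid"]

def select_script(script_index):
    # Gradio may deliver None instead of an integer index; treat it as entry 0.
    if script_index is None:
        script_index = 0
    # Entry 0 selects no script; entries 1..n map to the list, offset by one.
    return selectable_scripts[script_index - 1] if script_index > 0 else None

assert select_script(None) is None
assert select_script(0) is None
assert select_script(2) == "loopback"
```
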
changed, 3 insertions(+) diff --git a/modules/infotext.py b/modules/infotext.py index 86a36c3273c..bcbeb0fdcd1 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -4,12 +4,15 @@ import json import os import re +import sys import gradio as gr from modules.paths import data_path from modules import shared, ui_tempdir, script_callbacks, processing from PIL import Image +sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name + re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") From d859cec696a953dbfd6f69f7735e68661748d579 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:53:12 +0300 Subject: [PATCH 0512/1174] infotext.py: rename usages in the codebase --- .../scripts/extra_options_section.py | 4 ++-- modules/api/api.py | 10 +++++----- modules/img2img.py | 2 +- modules/postprocessing.py | 4 ++-- modules/processing.py | 4 ++-- modules/processing_scripts/refiner.py | 2 +- modules/processing_scripts/seed.py | 2 +- modules/shared_items.py | 4 ++-- modules/txt2img.py | 2 +- modules/ui.py | 4 ++-- modules/ui_common.py | 4 ++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_user_metadata.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 14 files changed, 25 insertions(+), 25 deletions(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index ac2c3de4643..8aa901fd481 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste +from modules import scripts, shared, ui_components, ui_settings, infotext from modules.ui_components import FormColumn @@ -25,7 +25,7 @@ def ui(self, is_img2img): extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") - mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} + mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping} with gr.Blocks() as interface: with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): diff --git a/modules/api/api.py b/modules/api/api.py index 843c59b0400..0e2807de2b6 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -369,9 +369,9 @@ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_scri if not request.infotext: return {} - possible_fields = 
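
The one-line alias added in patch 0511 is a standard way to keep old import paths alive after a rename: registering the current module object in `sys.modules` under its former dotted name makes the old `import` statements resolve to the renamed module. A standalone sketch of the mechanism; the package and function names below are invented for the demo:

```python
import importlib
import sys
import types

# Stand-ins for a package and its renamed submodule.
pkg = types.ModuleType("mypkg")
new_module = types.ModuleType("mypkg.infotext")
new_module.parse_generation_parameters = lambda s: dict(kv.split(": ") for kv in s.split(", "))
sys.modules["mypkg"] = pkg
sys.modules["mypkg.infotext"] = new_module

# The backwards-compatibility alias: the old dotted path maps to the same object.
sys.modules["mypkg.generation_parameters_copypaste"] = sys.modules["mypkg.infotext"]

old = importlib.import_module("mypkg.generation_parameters_copypaste")
print(old is new_module)  # True: old imports keep resolving to the renamed module
print(old.parse_generation_parameters("Steps: 20, Sampler: Euler"))
```
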
generation_parameters_copypaste.paste_fields[tabname]["fields"] + possible_fields = infotext.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this - params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + params = infotext.parse_generation_parameters(request.infotext) def get_field_value(field, params): value = field.function(params) if field.function else params.get(field.label) @@ -408,7 +408,7 @@ def get_field_value(field, params): if request.override_settings is None: request.override_settings = {} - overriden_settings = generation_parameters_copypaste.get_override_settings(params) + overriden_settings = infotext.get_override_settings(params) for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value @@ -584,7 +584,7 @@ def pnginfoapi(self, req: models.PNGInfoRequest): if geninfo is None: geninfo = "" - params = generation_parameters_copypaste.parse_generation_parameters(geninfo) + params = infotext.parse_generation_parameters(geninfo) script_callbacks.infotext_pasted_callback(geninfo, params) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) diff --git a/modules/img2img.py b/modules/img2img.py index c583290a0ea..75b3d346f9a 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,7 +7,7 @@ import gradio as gr from modules import images as imgutil -from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters +from modules.infotext import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 0c59fad480b..f776f7b69b9 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext from modules.shared import opts @@ -86,7 +86,7 @@ def get_images(extras_mode, image, image_folder, input_dir): basename = '' forced_filename = None - infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None]) if opts.enable_pnginfo: pp.image.info = existing_pnginfo diff --git a/modules/processing.py b/modules/processing.py index 7789f9a4239..b30df60db3a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import
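
Note the trap visible in the postprocessing.py hunk above: the local variable `infotext` assigned there shadows the freshly imported `infotext` module inside the function, which is exactly the awkwardness patch 0520 later resolves by renaming the module to `infotext_utils`. The serialization itself, joining `key: quote(value)` pairs into the one-line PNG infotext, can be sketched on its own; the `quote` below is a simplified stand-in for the real helper, and the parameter values are invented:

```python
import json

def quote(text):
    # Values containing separators are JSON-quoted so the line stays parseable.
    if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
        return text
    return json.dumps(text, ensure_ascii=False)

params = {
    "Steps": 20,
    "Sampler": "DPM++ 2M Karras",
    "CFG scale": 7,
    "Hires prompt": "a cat, a dog",  # needs quoting because of the comma
    "Model hash": None,              # None entries are dropped entirely
}

line = ", ".join(f"{k}: {quote(v)}" if k != v else k for k, v in params.items() if v is not None)
print(line)
# Steps: 20, Sampler: DPM++ 2M Karras, CFG scale: 7, Hires prompt: "a cat, a dog"
```
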
images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -733,7 +733,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "User": p.user if opts.add_user_name_to_info else None, } - generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index cefad32b761..e9941413f24 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,7 +1,7 @@ import gradio as gr from modules import scripts, sd_models -from modules.generation_parameters_copypaste import PasteField +from modules.infotext import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index a3e16a12e9b..602932785fe 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,7 +3,7 @@ import gradio as gr from modules import scripts, ui, errors -from modules.generation_parameters_copypaste import PasteField +from modules.infotext import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton diff --git a/modules/shared_items.py b/modules/shared_items.py index 991971ad0fb..e1392472099 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -67,14 +67,14 @@ def reload_hypernetworks(): def get_infotext_names(): - from modules import generation_parameters_copypaste, shared + from modules import infotext, shared res = {} for info in shared.opts.data_labels.values(): if info.infotext: res[info.infotext] = 1 - for tab_data in generation_parameters_copypaste.paste_fields.values(): + for tab_data in infotext.paste_fields.values(): for _, name in tab_data.get("fields") or []: if isinstance(name, str): res[name] = 1 diff --git a/modules/txt2img.py b/modules/txt2img.py index e4e18ceb6dd..3a481915f3a 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -2,7 +2,7 @@ import modules.scripts from modules import processing -from modules.generation_parameters_copypaste import create_override_settings_dict +from modules.infotext import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html diff --git a/modules/ui.py b/modules/ui.py index 9db2407ed30..6451e14c14a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,14 +21,14 @@ from modules.shared import opts, cmd_opts -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste import modules.hypernetworks.ui as hypernetworks_ui import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.generation_parameters_copypaste import image_from_url_text, PasteField +from modules.infotext import 
image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component diff --git a/modules/ui_common.py b/modules/ui_common.py index 032ec4af762..fd32676f956 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -8,10 +8,10 @@ import subprocess as sp from modules import call_queue, shared -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text import modules.images from modules.ui_components import ToolButton -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index b8c02241365..790af135665 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ import html from fastapi.exceptions import HTTPException -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text from modules.ui_components import ToolButton extra_pages = [] diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 36a807fcdf9..87aeb6f3359 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -5,7 +5,7 @@ import gradio as gr -from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks +from modules import infotext, images, sysinfo, errors, ui_extra_networks class UserMetadataEditor: @@ -181,7 +181,7 @@ def save_preview(self, index, gallery, name): index = len(gallery) - 1 if index >= len(gallery) else index img_info = gallery[index if index >= 0 else 0] - image = generation_parameters_copypaste.image_from_url_text(img_info) + image = infotext.image_from_url_text(img_info) geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 13d888e48d1..b74a153231c 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,6 +1,6 @@ import gradio as gr from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste def create_ui(): From d613cd17c72c753bd1e314dff74dc22d9a949374 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 14:38:29 +0300 Subject: [PATCH 0513/1174] add automatic backwards version compatibility --- modules/infotext.py | 4 +++- modules/infotext_versions.py | 35 +++++++++++++++++++++++++++++++++++ modules/shared_options.py | 1 + 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 modules/infotext_versions.py diff --git a/modules/infotext.py b/modules/infotext.py index bcbeb0fdcd1..7f30446b9b4 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -8,7 +8,7 @@ import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks, processing +from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, errors from PIL import Image sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name @@ -342,6 +342,8 @@ def parse_generation_parameters(x: str): if "Cache FP16 weight for LoRA" not in res and res["FP8 
weight"] != "Disable": res["Cache FP16 weight for LoRA"] = False + infotext_versions.backcompat(res) + skip = set(shared.opts.infotext_skip_pasting) res = {k: v for k, v in res.items() if k not in skip} diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py new file mode 100644 index 00000000000..01e885a26ef --- /dev/null +++ b/modules/infotext_versions.py @@ -0,0 +1,35 @@ +from modules import shared +from packaging import version +import re + + +v160 = version.parse("1.6.0") + + +def parse_version(text): + if text is None: + return None + + m = re.match(r'([^-]+-[^-]+)-.*', text) + if m: + text = m.group(1) + + try: + return version.parse(text) + except Exception as e: + return None + + +def backcompat(d): + """Checks infotext Version field, and enables backwards compatibility options according to it.""" + + if not shared.opts.auto_backcompat: + return + + ver = parse_version(d.get("Version")) + if ver is None: + return + + if ver < v160: + d["Old prompt editing timelines"] = True + diff --git a/modules/shared_options.py b/modules/shared_options.py index 752a4f1259a..281591da810 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -212,6 +212,7 @@ })) options_templates.update(options_section(('compatibility', "Compatibility", "sd"), { + "auto_backcompat": OptionInfo(True, "Automatic backward compatibility").info("automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version."), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), From 45b7bba3d06f2d4bc2fffc210cbfcb357b86add6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 14:51:56 +0300 Subject: [PATCH 0514/1174] add automatic version support for zero terminal SNR noise schedule option from #14145 --- modules/infotext_versions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py index 01e885a26ef..9a204d84292 100644 --- a/modules/infotext_versions.py +++ b/modules/infotext_versions.py @@ -4,6 +4,7 @@ v160 = version.parse("1.6.0") +v170_tsnr = version.parse("v1.7.0-225") def parse_version(text): @@ -33,3 +34,6 @@ def backcompat(d): if ver < v160: d["Old prompt editing timelines"] = True + if ver < v170_tsnr: + d["Downcast alphas_cumprod"] = True + From d8126be578c7d4579c0f2ee4adbe35500bc71ce6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 15:00:39 +0300 Subject: [PATCH 0515/1174] linter --- modules/infotext.py | 2 +- modules/infotext_versions.py | 2 +- modules/postprocessing.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/infotext.py b/modules/infotext.py index 7f30446b9b4..26e9b94934f 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -8,7 +8,7 @@ import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, errors +from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions from PIL import Image sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name diff --git a/modules/infotext_versions.py 
b/modules/infotext_versions.py index 9a204d84292..a5afeebf136 100644 --- a/modules/infotext_versions.py +++ b/modules/infotext_versions.py @@ -17,7 +17,7 @@ def parse_version(text): try: return version.parse(text) - except Exception as e: + except Exception: return None diff --git a/modules/postprocessing.py b/modules/postprocessing.py index f776f7b69b9..facea899f37 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common from modules.shared import opts From 0743ee9b3eda8dd4ceea625d710031577201f4ad Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 15:50:47 +0300 Subject: [PATCH 0516/1174] re-layout checkboxes for XYZ grid a bit --- scripts/xyz_grid.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 34267c2c3bc..2d550994709 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -438,17 +438,16 @@ def ui(self, is_img2img): with gr.Column(): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + with gr.Row(): + vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.") + vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.") + vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.") with gr.Column(): include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) - with gr.Column(): - vary_seeds_x = gr.Checkbox(label='Vary seed on X axis', value=False, elem_id=self.elem_id("vary_seeds_x")) - vary_seeds_y = gr.Checkbox(label='Vary seed on Y axis', value=False, elem_id=self.elem_id("vary_seeds_y")) - vary_seeds_z = gr.Checkbox(label='Vary seed on Z axis', value=False, elem_id=self.elem_id("vary_seeds_z")) + csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Column(): margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) - with gr.Column(): - csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Row(variant="compact", elem_id="swap_axes"): swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") @@ -531,7 +530,7 @@ def get_dropdown_update_from_params(axis, params): return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode] - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, 
z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode): + def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode): x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0 if not no_fixed_seeds: From ac0ecf3b4b9d147743c04f0ff4ddc4cf4595e11d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 16:28:58 +0300 Subject: [PATCH 0517/1174] option to convert VAE to bfloat16 (implementation of #9295) --- modules/processing.py | 23 ++++++++++++++++++----- modules/shared_options.py | 1 + 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 846e4796af2..f065688218a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -628,20 +628,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): sample = decode_first_stage(model, batch[i:i + 1])[0] if check_for_nans: + try: devices.test_for_nans(sample, "vae") except devices.NansException as e: - if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision: + if shared.opts.auto_vae_precision_bfloat16: + autofix_dtype = torch.bfloat16 + autofix_dtype_text = "bfloat16" + autofix_dtype_setting = "Automatically convert VAE to bfloat16" + autofix_dtype_comment = "" + elif shared.opts.auto_vae_precision: + autofix_dtype = torch.float32 + autofix_dtype_text = "32-bit float" + autofix_dtype_setting = "Automatically revert VAE to 32-bit floats" + autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag." + else: + raise e + + if devices.dtype_vae == autofix_dtype: raise e errors.print_error_explanation( "A tensor with all NaNs was produced in VAE.\n" - "Web UI will now convert VAE into 32-bit float and retry.\n" - "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n" - "To always start with 32-bit VAE, use --no-half-vae commandline flag." 
+ f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n" + f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}" ) - devices.dtype_vae = torch.float32 + devices.dtype_vae = autofix_dtype model.first_stage_model.to(devices.dtype_vae) batch = batch.to(devices.dtype_vae) diff --git a/modules/shared_options.py b/modules/shared_options.py index ce06f022eaf..e813546f6a5 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -177,6 +177,7 @@ "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), + "auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"), "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"), From 0aa7c53c0b9469849377aff83f43c9f75c19b3fa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 16:50:59 +0300 Subject: [PATCH 0518/1174] fix borked merge, rename fields to better match what they do, change setting default to true for #13653 --- modules/call_queue.py | 2 +- modules/img2img.py | 2 +- modules/processing.py | 2 +- modules/shared_options.py | 2 +- modules/shared_state.py | 12 ++++++------ modules/ui_toprow.py | 8 +++++++- scripts/loopback.py | 4 ++-- scripts/xyz_grid.py | 2 +- 8 files changed, 20 insertions(+), 14 deletions(-) diff --git a/modules/call_queue.py b/modules/call_queue.py index 01c6d17f685..bcd7c546200 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -78,7 +78,7 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs): shared.state.skipped = False shared.state.interrupted = False - shared.state.interrupted_next = False + shared.state.stopping_generation = False shared.state.job_count = 0 if not add_stats: diff --git a/modules/img2img.py b/modules/img2img.py index 829faa818ff..e7e8e2510ca 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -51,7 +51,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if state.skipped: state.skipped = False - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break try: diff --git a/modules/processing.py b/modules/processing.py index 00de2ed2e9a..f55b85ed350 100644 --- 
a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.skipped: state.skipped = False - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break sd_models.reload_model_weights() # model can be changed for example by refiner diff --git a/modules/shared_options.py b/modules/shared_options.py index 7852e0ea362..7581e276e25 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -120,7 +120,6 @@ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), - "interrupt_after_current": OptionInfo(False, "Interrupt generation after current image is finished on batch processing"), })) options_templates.update(options_section(('API', "API", "system"), { @@ -286,6 +285,7 @@ "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(), "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(), + "interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"), })) options_templates.update(options_section(('ui', "User interface", "ui"), { diff --git a/modules/shared_state.py b/modules/shared_state.py index 532fdcd8d12..33996691c40 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -12,7 +12,7 @@ class State: skipped = False interrupted = False - interrupted_next = False + stopping_generation = False job = "" job_no = 0 job_count = 0 @@ -80,9 +80,9 @@ def interrupt(self): self.interrupted = True log.info("Received interrupt request") - def interrupt_next(self): - self.interrupted_next = True - log.info("Received interrupt request, interrupt after current job") + def stop_generating(self): + self.stopping_generation = True + log.info("Received stop generating request") def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: @@ -96,7 +96,7 @@ def dict(self): obj = { "skipped": self.skipped, "interrupted": self.interrupted, - "interrupted_next": self.interrupted_next, + "stopping_generation": self.stopping_generation, "job": self.job, "job_count": self.job_count, "job_timestamp": self.job_timestamp, @@ -120,7 +120,7 @@ def begin(self, job: str = "(unknown)"): self.id_live_preview = 0 self.skipped = False self.interrupted = False - self.interrupted_next = False + self.stopping_generation = False self.textinfo = None self.job = job devices.torch_gc() diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 9caf8faa2de..1abc9117ba3 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -106,8 +106,14 @@ def create_submit_box(self): outputs=[], ) + def interrupt_function(): + if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + shared.state.stop_generating() + else: + shared.state.interrupt() + self.interrupt.click( - fn=lambda: shared.state.interrupt(), + 
fn=interrupt_function, inputs=[], outputs=[], ) diff --git a/scripts/loopback.py b/scripts/loopback.py index ad921269afb..800ee882a16 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -95,7 +95,7 @@ def calculate_denoising_strength(loop): processed = processing.process_images(p) # Generation cancelled. - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break if initial_seed is None: @@ -122,7 +122,7 @@ def calculate_denoising_strength(loop): p.inpainting_fill = original_inpainting_fill - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break if len(history) > 1: diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 2deff365fa1..2f385ebf295 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -696,7 +696,7 @@ def fix_axis_seeds(axis_opt, axis_list): grid_infotext = [None] * (1 + len(zs)) def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted or state.interrupted_next: + if shared.state.interrupted or state.stopping_generation: return Processed(p, [], p.seed, "") pc = copy(p) From 1ffdedc11d49862cf0d030fb0bcc25eb0449939b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 17:03:08 +0300 Subject: [PATCH 0519/1174] restore lines lost from #13789 merge --- modules/cmd_args.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 3775e670261..e58059a1fea 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -93,7 +93,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) -parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") @@ -115,8 +115,9 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') +parser.add_argument('--add-stop-route', action='store_true', help='does not do anything') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') 
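
Patch 0518 renames `interrupted_next` to `stopping_generation` and moves the choice between the two cancellation levels into the Interrupt button handler. The protocol reduces to two flags polled once per image by the batch loop; a minimal sketch in plain Python, with no Gradio and invented image names:

```python
class State:
    interrupted = False          # abort immediately, even mid-image
    stopping_generation = False  # finish the current image, then stop

    def interrupt(self):
        self.interrupted = True

    def stop_generating(self):
        self.stopping_generation = True

def on_interrupt_button(state, job_count, interrupt_after_current):
    # Mirrors ui_toprow: with more than one image queued and the option on,
    # request a stop at the next image boundary instead of an instant abort.
    if job_count > 1 and interrupt_after_current:
        state.stop_generating()
    else:
        state.interrupt()

state = State()
images = []
for i in range(4):
    # The batch loop checks both flags once per image, like process_images_inner.
    if state.interrupted or state.stopping_generation:
        break
    images.append(f"image {i}")
    if i == 1:  # simulate the user pressing Interrupt while image 1 renders
        on_interrupt_button(state, job_count=4, interrupt_after_current=True)

print(images)  # ['image 0', 'image 1']: image 1 completed, the rest were skipped
```
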
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) -parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) +parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) From 5d7d1823afab0a051a3fbbdb3213bae8051350b7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 17:25:30 +0300 Subject: [PATCH 0520/1174] rename infotext.py again, this time to infotext_utils.py; I didn't realize infotext would be used for variable names in multiple places, which makes it awkward to import the module; also fix the bug I caused by this rename that breaks tests --- .../scripts/extra_options_section.py | 4 ++-- modules/api/api.py | 10 +++++----- modules/img2img.py | 2 +- modules/{infotext.py => infotext_utils.py} | 0 modules/postprocessing.py | 4 ++-- modules/processing.py | 4 ++-- modules/processing_scripts/refiner.py | 2 +- modules/processing_scripts/seed.py | 2 +- modules/shared_items.py | 4 ++-- modules/txt2img.py | 2 +- modules/ui.py | 4 ++-- modules/ui_common.py | 4 ++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_user_metadata.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 15 files changed, 25 insertions(+), 25 deletions(-) rename modules/{infotext.py => infotext_utils.py} (100%) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 8aa901fd481..4c10d9c7d6a 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, infotext +from modules import scripts, shared, ui_components, ui_settings, infotext_utils from modules.ui_components import FormColumn @@ -25,7 +25,7 @@ def ui(self, is_img2img): extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") - mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping} + mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping} with gr.Blocks() as interface: with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): diff --git a/modules/api/api.py b/modules/api/api.py index 0e2807de2b6..9d1292e9571 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, 
script_callbacks, infotext_utils, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -369,9 +369,9 @@ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_scri if not request.infotext: return {} - possible_fields = infotext.paste_fields[tabname]["fields"] + possible_fields = infotext_utils.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this - params = infotext.parse_generation_parameters(request.infotext) + params = infotext_utils.parse_generation_parameters(request.infotext) def get_field_value(field, params): value = field.function(params) if field.function else params.get(field.label) @@ -408,7 +408,7 @@ def get_field_value(field, params): if request.override_settings is None: request.override_settings = {} - overriden_settings = infotext.get_override_settings(params) + overriden_settings = infotext_utils.get_override_settings(params) for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value @@ -584,7 +584,7 @@ def pnginfoapi(self, req: models.PNGInfoRequest): if geninfo is None: geninfo = "" - params = infotext.parse_generation_parameters(geninfo) + params = infotext_utils.parse_generation_parameters(geninfo) script_callbacks.infotext_pasted_callback(geninfo, params) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) diff --git a/modules/img2img.py b/modules/img2img.py index e7e8e2510ca..04de8e62cf2 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,7 +7,7 @@ import gradio as gr from modules import images as imgutil -from modules.infotext import create_override_settings_dict, parse_generation_parameters +from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match diff --git a/modules/infotext.py b/modules/infotext_utils.py similarity index 100% rename from modules/infotext.py rename to modules/infotext_utils.py diff --git a/modules/postprocessing.py b/modules/postprocessing.py index facea899f37..7850328f63f 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils from modules.shared import opts @@ -86,7 +86,7 @@ def get_images(extras_mode, image, image_folder, input_dir): basename = '' forced_filename = None - infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None]) + infotext = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None]) if opts.enable_pnginfo: pp.image.info = existing_pnginfo diff --git a/modules/processing.py b/modules/processing.py index f55b85ed350..213a2879cc6 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, 
extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -746,7 +746,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "User": p.user if opts.add_user_name_to_info else None, } - generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index e9941413f24..ba33d8a4b80 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,7 +1,7 @@ import gradio as gr from modules import scripts, sd_models -from modules.infotext import PasteField +from modules.infotext_utils import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 602932785fe..2d3cbb97f06 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,7 +3,7 @@ import gradio as gr from modules import scripts, ui, errors -from modules.infotext import PasteField +from modules.infotext_utils import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton diff --git a/modules/shared_items.py b/modules/shared_items.py index e1392472099..13fb2814f2b 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -67,14 +67,14 @@ def reload_hypernetworks(): def get_infotext_names(): - from modules import infotext, shared + from modules import infotext_utils, shared res = {} for info in shared.opts.data_labels.values(): if info.infotext: res[info.infotext] = 1 - for tab_data in infotext.paste_fields.values(): + for tab_data in infotext_utils.paste_fields.values(): for _, name in tab_data.get("fields") or []: if isinstance(name, str): res[name] = 1 diff --git a/modules/txt2img.py b/modules/txt2img.py index 3a481915f3a..49660e89167 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -2,7 +2,7 @@ import modules.scripts from modules import processing -from modules.infotext import create_override_settings_dict +from modules.infotext_utils import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html diff --git a/modules/ui.py b/modules/ui.py index 378529c79a4..52b15646ab5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,14 +21,14 @@ from modules.shared import opts, cmd_opts -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste import modules.hypernetworks.ui as hypernetworks_ui import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion 
as textual_inversion import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.infotext import image_from_url_text, PasteField +from modules.infotext_utils import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component diff --git a/modules/ui_common.py b/modules/ui_common.py index fd32676f956..f48ad426072 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -8,10 +8,10 @@ import subprocess as sp from modules import call_queue, shared -from modules.infotext import image_from_url_text +from modules.infotext_utils import image_from_url_text import modules.images from modules.ui_components import ToolButton -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 790af135665..beea1316083 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ import html from fastapi.exceptions import HTTPException -from modules.infotext import image_from_url_text +from modules.infotext_utils import image_from_url_text from modules.ui_components import ToolButton extra_pages = [] diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 87aeb6f3359..989a649b799 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -5,7 +5,7 @@ import gradio as gr -from modules import infotext, images, sysinfo, errors, ui_extra_networks +from modules import infotext_utils, images, sysinfo, errors, ui_extra_networks class UserMetadataEditor: @@ -181,7 +181,7 @@ def save_preview(self, index, gallery, name): index = len(gallery) - 1 if index >= len(gallery) else index img_info = gallery[index if index >= 0 else 0] - image = infotext.image_from_url_text(img_info) + image = infotext_utils.image_from_url_text(img_info) geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index b74a153231c..1edb68c5ced 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,6 +1,6 @@ import gradio as gr from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste def create_ui(): From 501993ebf210bf3b55173ec1910f0c84c7e75424 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 19:31:06 +0300 Subject: [PATCH 0521/1174] added a button to run hires fix on selected image in the gallery --- javascript/ui.js | 8 +++ modules/processing.py | 46 ++++++++++++--- modules/txt2img.py | 19 +++++- modules/ui.py | 108 +++++++++++++++++++---------------- modules/ui_common.py | 57 +++++++++++------- modules/ui_postprocessing.py | 8 +-- 6 files changed, 160 insertions(+), 86 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index 18c9f891afc..78d1caee5bf 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -150,6 +150,14 @@ function submit() { return res; } +function submit_txt2img_upscale() { + res = submit.apply(null, arguments); + + res[2] = selected_gallery_index(); + + return res; +} + function submit_img2img() { showSubmitButtons('img2img', false); diff --git 
a/modules/processing.py b/modules/processing.py index 213a2879cc6..045c7d7955e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -179,6 +179,7 @@ class StableDiffusionProcessing: token_merging_ratio = 0 token_merging_ratio_hr = 0 disable_extra_networks: bool = False + firstpass_image: Image = None scripts_value: scripts.ScriptRunner = field(default=None, init=False) script_args_value: list = field(default=None, init=False) @@ -1238,18 +1239,45 @@ def init(self, all_prompts, all_seeds, all_subseeds): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - x = self.rng.next() - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) - del x + if self.firstpass_image is not None and self.enable_hr: + # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix - if not self.enable_hr: - return samples - devices.torch_gc() + if self.latent_scale_mode is None: + image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0 + image = np.moveaxis(image, 2, 0) + + samples = None + decoded_samples = torch.asarray(np.expand_dims(image, 0)) + + else: + image = np.array(self.firstpass_image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + image = torch.from_numpy(np.expand_dims(image, axis=0)) + image = image.to(shared.device, dtype=devices.dtype_vae) + + if opts.sd_vae_encode_method != 'Full': + self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + + samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) + decoded_samples = None + devices.torch_gc() - if self.latent_scale_mode is None: - decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) else: - decoded_samples = None + # here we generate an image normally + + x = self.rng.next() + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + del x + + if not self.enable_hr: + return samples + + devices.torch_gc() + + if self.latent_scale_mode is None: + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + else: + decoded_samples = None with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) diff --git a/modules/txt2img.py b/modules/txt2img.py index 49660e89167..4a6fe72a65d 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,7 +1,7 @@ from contextlib import closing import modules.scripts -from modules import processing +from modules import processing, infotext_utils from modules.infotext_utils import create_override_settings_dict from modules.shared import opts import modules.shared as shared @@ -9,9 +9,23 @@ import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: 
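
When `firstpass_image` is set, the reworked `sample` skips sampling entirely and only massages the supplied image into the tensor layout the hires pass expects: a [-1, 1] CHW tensor for image-space upscalers, or a VAE-encoded latent for latent upscalers. A hedged sketch of just that tensor shuffling, assuming numpy, torch and Pillow are installed; `encode_to_latent` is a placeholder for the real `images_tensor_to_samples` VAE encode:

```python
import numpy as np
import torch

def encode_to_latent(image):
    # Stand-in: a real implementation would run the VAE encoder here.
    return image

def prepare_firstpass_image(pil_image, latent_scale_mode):
    """Turn a PIL image into the tensor the hires pass expects (sketch)."""
    if latent_scale_mode is None:
        # Image-space upscaler path: HWC uint8 -> CHW float in [-1, 1], batched.
        arr = np.array(pil_image).astype(np.float32) / 255.0 * 2.0 - 1.0
        arr = np.moveaxis(arr, 2, 0)
        return torch.asarray(np.expand_dims(arr, 0))  # "decoded samples"
    # Latent upscaler path: HWC uint8 -> CHW float in [0, 1], then VAE-encode.
    arr = np.array(pil_image).astype(np.float32) / 255.0
    arr = np.moveaxis(arr, 2, 0)
    image = torch.from_numpy(np.expand_dims(arr, axis=0))
    return encode_to_latent(image)

if __name__ == "__main__":
    from PIL import Image
    img = Image.new("RGB", (64, 64), "gray")
    print(prepare_firstpass_image(img, None).shape)      # torch.Size([1, 3, 64, 64])
    print(prepare_firstpass_image(img, "Latent").shape)  # torch.Size([1, 3, 64, 64])
```
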
gr.Request, *args): +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args): + assert len(gallery) > 0, 'No image to upscale' + + image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] + image = infotext_utils.image_from_url_text(image_info) + + return txt2img(id_task, request, *args, firstpass_image=image) + + +def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, firstpass_image=None): override_settings = create_override_settings_dict(override_settings_texts) + if firstpass_image is not None: + enable_hr = True + batch_size = 1 + n_iter = 1 + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -38,6 +52,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, + firstpass_image=firstpass_image, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 52b15646ab5..3d548430d36 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -375,50 +375,60 @@ def create_ui(): show_progress=False, ) - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) + output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) + + txt2img_inputs = [ + dummy_component, + toprow.prompt, + toprow.negative_prompt, + toprow.ui_styles.dropdown, + steps, + sampler_name, + batch_count, + batch_size, + cfg_scale, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + hr_checkpoint_name, + hr_sampler_name, + hr_prompt, + hr_negative_prompt, + override_settings, + ] + custom_inputs + + txt2img_outputs = [ + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, + ] txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", - inputs=[ - dummy_component, - toprow.prompt, - toprow.negative_prompt, - toprow.ui_styles.dropdown, - steps, - sampler_name, - batch_count, - batch_size, - cfg_scale, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - hr_checkpoint_name, - hr_sampler_name, - hr_prompt, - hr_negative_prompt, - override_settings, - - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], + inputs=txt2img_inputs, + outputs=txt2img_outputs, show_progress=False, ) toprow.prompt.submit(**txt2img_args) toprow.submit.click(**txt2img_args) + output_panel.button_upscale.click( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']), + _js="submit_txt2img_upscale", + inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:], + outputs=txt2img_outputs, + show_progress=False, + ) + res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", 
inputs=None, outputs=None, show_progress=False) toprow.restore_progress_button.click( @@ -426,10 +436,10 @@ def create_ui(): _js="restoreProgressTxt2img", inputs=[dummy_component], outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -479,7 +489,7 @@ def create_ui(): toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') - ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery) extra_tabs.__exit__() @@ -710,7 +720,7 @@ def select_img2img_tab(tab): outputs=[inpaint_controls, mask_alpha], ) - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) + output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), @@ -755,10 +765,10 @@ def select_img2img_tab(tab): img2img_batch_png_info_dir, ] + custom_inputs, outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -796,10 +806,10 @@ def select_img2img_tab(tab): _js="restoreProgressImg2img", inputs=[dummy_component], outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -839,7 +849,7 @@ def select_img2img_tab(tab): )) extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') - ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery) extra_tabs.__exit__() diff --git a/modules/ui_common.py b/modules/ui_common.py index f48ad426072..ff84197c1de 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -1,3 +1,4 @@ +import dataclasses import json import html import os @@ -104,7 +105,17 @@ def __init__(self, d=None): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") +@dataclasses.dataclass +class OutputPanel: + gallery = None + infotext = None + html_info = None + html_log = None + button_upscale = None + + def create_output_panel(tabname, outdir, toprow=None): + res = OutputPanel() def open_folder(f): if not os.path.exists(f): @@ -136,9 +147,8 @@ def open_folder(f): with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"): with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) - generation_info = None with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, 
tooltip="Open images output directory.") @@ -152,6 +162,9 @@ def open_folder(f): 'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.") } + if tabname == 'txt2img': + res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.") + open_folder_button.click( fn=lambda: open_folder(shared.opts.outdir_samples or outdir), inputs=[], @@ -162,17 +175,17 @@ def open_folder(f): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") + res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[generation_info, html_info, html_info], - outputs=[html_info, html_info], + inputs=[res.infotext, res.html_info, res.html_info], + outputs=[res.html_info, res.html_info], show_progress=False, ) @@ -180,14 +193,14 @@ def open_folder(f): fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ - generation_info, - result_gallery, - html_info, - html_info, + res.infotext, + res.gallery, + res.html_info, + res.html_info, ], outputs=[ download_files, - html_log, + res.html_log, ], show_progress=False, ) @@ -196,21 +209,21 @@ def open_folder(f): fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", inputs=[ - generation_info, - result_gallery, - html_info, - html_info, + res.infotext, + res.gallery, + res.html_info, + res.html_info, ], outputs=[ download_files, - html_log, + res.html_log, ] ) else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}') + res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}') + res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.html_log = gr.HTML(elem_id=f'html_log_{tabname}') paste_field_names = [] if tabname == "txt2img": @@ -220,11 +233,11 @@ def open_folder(f): for paste_tabname, paste_button in buttons.items(): parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery, + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=res.gallery, paste_field_names=paste_field_names )) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + return res def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): diff --git 
a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 1edb68c5ced..8f09e6588a3 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -28,7 +28,7 @@ def create_ui(): toprow.create_inline_toprow_image() submit = toprow.submit - result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) + output_panel = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) @@ -48,9 +48,9 @@ def create_ui(): *script_inputs ], outputs=[ - result_images, - html_info_x, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_log, ], show_progress=False, ) From c32c51a0fc49ca4fbfe02bfbae45ec49e7bf9876 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 1 Jan 2024 19:20:54 +0200 Subject: [PATCH 0522/1174] Fix lint issue from 501993eb --- javascript/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/ui.js b/javascript/ui.js index 78d1caee5bf..3430b3fefd8 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -151,7 +151,7 @@ function submit() { } function submit_txt2img_upscale() { - res = submit.apply(null, arguments); + var res = submit(...arguments); res[2] = selected_gallery_index(); From c2ea571005dab29b285e31a0ad4a97258360bf2d Mon Sep 17 00:00:00 2001 From: Jibaku789 <151478027+Jibaku789@users.noreply.github.com> Date: Mon, 1 Jan 2024 14:57:41 -0600 Subject: [PATCH 0523/1174] Add inpaint options to paste fields --- modules/ui.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index 3d548430d36..4cdb0e9c0f3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -840,6 +840,10 @@ def select_img2img_tab(tab): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), + (inpainting_mask_invert, 'Mask mode'), + (inpainting_fill, 'Masked content'), + (inpaint_full_res, 'Inpaint area'), + (inpaint_full_res_padding, 'Only masked padding, pixels'), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) From a5b6a5a3adcc845237d872750ded34240cc6a810 Mon Sep 17 00:00:00 2001 From: Jibaku789 <151478027+Jibaku789@users.noreply.github.com> Date: Mon, 1 Jan 2024 14:58:55 -0600 Subject: [PATCH 0524/1174] Add inpaint options to img2img.py --- modules/img2img.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/modules/img2img.py b/modules/img2img.py index 04de8e62cf2..9e09c0a00cc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -225,6 +225,18 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if mask: p.extra_generation_params["Mask blur"] = mask_blur + if inpainting_mask_invert is not None: + p.extra_generation_params["Mask mode"] = inpainting_mask_invert + + if inpainting_fill is not None: + p.extra_generation_params["Masked content"] = inpainting_fill + + if inpaint_full_res is not None: + p.extra_generation_params["Inpaint area"] = inpaint_full_res + + if inpaint_full_res_padding is not None: + p.extra_generation_params["Only masked padding, pixels"] = inpaint_full_res_padding + with closing(p): if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" 
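Taken together, the two patches above complete a round trip: img2img.py now records the inpaint settings in p.extra_generation_params, those settings get serialized into the image's infotext, and the new paste fields in ui.py read them back into the UI. Below is a minimal, self-contained sketch of that key/value contract; serialize() and parse() are hypothetical stand-ins for the webui's real infotext writer and parse_generation_parameters, and only the key names are taken from the patches.

    # Hypothetical stand-ins illustrating the infotext key/value contract.
    # Real infotexts also quote values that contain commas; omitted here.
    def serialize(params: dict) -> str:
        return ", ".join(f"{key}: {value}" for key, value in params.items())

    def parse(infotext: str) -> dict:
        result = {}
        for pair in infotext.split(", "):
            key, _, value = pair.partition(": ")
            result[key] = value
        return result

    # What img2img.py now records into extra_generation_params ...
    extra_generation_params = {
        "Mask blur": 4,
        "Inpaint area": "Only masked",
        "Only masked padding, pixels": 32,
    }
    infotext = serialize(extra_generation_params)

    # ... and what the paste fields added to ui.py can read back.
    assert parse(infotext)["Inpaint area"] == "Only masked"
    assert parse(infotext)["Only masked padding, pixels"] == "32"

Note that parsed values come back as strings, which is why the paste machinery has to cast them to each field's type on the way in.
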
From 1341b2208185cd89b0019bda2df63b406ec0cb5e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 06:47:26 +0300 Subject: [PATCH 0525/1174] add an option to hide upscaling progressbar --- modules/shared_options.py | 1 + modules/upscaler_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index cca3f7be3c7..63488f4e700 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -115,6 +115,7 @@ "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), + "enable_upscale_progressbar": OptionInfo(True, "Show a progress bar in the console for tiled upscaling."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index f5cb92d51b8..9379f512b12 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -47,7 +47,7 @@ def upscale_with_model( grid = images.split_grid(img, tile_size, tile_size, tile_overlap) newtiles = [] - with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + with tqdm.tqdm(total=grid.tile_count, desc=desc, disable=not shared.opts.enable_upscale_progressbar) as p: for y, h, row in grid.tiles: newrow = [] for x, w, tile in row: @@ -103,7 +103,7 @@ def tiled_upscale_2( ).type_as(img) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) - with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: for h_idx in h_idx_list: if shared.state.interrupted or shared.state.skipped: break From 80873b1538e6ca0c7ebe558f8ce4213b06fd8307 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 07:05:05 +0300 Subject: [PATCH 0526/1174] fix #14497 --- modules/img2img.py | 15 --------------- modules/infotext_utils.py | 12 ++++++++++++ modules/processing.py | 13 +++++++++++++ 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 9e09c0a00cc..f81405df5db 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -222,21 +222,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if shared.opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) - if mask: - p.extra_generation_params["Mask blur"] = mask_blur - - if inpainting_mask_invert is not None: - p.extra_generation_params["Mask mode"] = inpainting_mask_invert - - if inpainting_fill is not None: - p.extra_generation_params["Masked content"] = inpainting_fill - - if inpaint_full_res is not None: - p.extra_generation_params["Inpaint area"] = inpaint_full_res - - if inpaint_full_res_padding is not None: - p.extra_generation_params["Only masked 
padding, pixels"] = inpaint_full_res_padding - with closing(p): if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 26e9b94934f..e582ee479c6 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -312,6 +312,18 @@ def parse_generation_parameters(x: str): if "Hires negative prompt" not in res: res["Hires negative prompt"] = "" + if "Mask mode" not in res: + res["Mask mode"] = "Inpaint masked" + + if "Masked content" not in res: + res["Masked content"] = 'original' + + if "Inpaint area" not in res: + res["Inpaint area"] = "Whole picture" + + if "Masked area padding" not in res: + res["Masked area padding"] = 32 + restore_old_hires_fix_params(res) # Missing RNG means the default was set, which is GPU RNG diff --git a/modules/processing.py b/modules/processing.py index 045c7d7955e..84e7b1b4577 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1530,6 +1530,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) + self.extra_generation_params["Mask mode"] = "Inpaint not masked" if self.mask_blur_x > 0: np_mask = np.array(image_mask) @@ -1543,6 +1544,9 @@ def init(self, all_prompts, all_seeds, all_subseeds): np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) image_mask = Image.fromarray(np_mask) + if self.mask_blur_x > 0 or self.mask_blur_y > 0: + self.extra_generation_params["Mask blur"] = self.mask_blur + if self.inpaint_full_res: self.mask_for_overlay = image_mask mask = image_mask.convert('L') @@ -1553,6 +1557,9 @@ def init(self, all_prompts, all_seeds, all_subseeds): mask = mask.crop(crop_region) image_mask = images.resize_image(2, mask, self.width, self.height) self.paste_to = (x1, y1, x2-x1, y2-y1) + + self.extra_generation_params["Inpaint area"] = "Only masked" + self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) np_mask = np.array(image_mask) @@ -1594,6 +1601,9 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.inpainting_fill != 1: image = masking.fill(image, latent_mask) + if self.inpainting_fill == 0: + self.extra_generation_params["Masked content"] = 'fill' + if add_color_corrections: self.color_corrections.append(setup_color_correction(image)) @@ -1643,8 +1653,11 @@ def init(self, all_prompts, all_seeds, all_subseeds): # this needs to be fixed to be done in sample() using actual seeds for batches if self.inpainting_fill == 2: self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + self.extra_generation_params["Masked content"] = 'latent noise' + elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask + self.extra_generation_params["Masked content"] = 'latent nothing' self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round) From 980970d39091e572500434c69660bc6eed22498d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 07:08:32 +0300 Subject: [PATCH 0527/1174] final touches --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 4cdb0e9c0f3..7116d71c582 100644 --- a/modules/ui.py +++ 
b/modules/ui.py @@ -843,7 +843,7 @@ def select_img2img_tab(tab): (inpainting_mask_invert, 'Mask mode'), (inpainting_fill, 'Masked content'), (inpaint_full_res, 'Inpaint area'), - (inpaint_full_res_padding, 'Only masked padding, pixels'), + (inpaint_full_res_padding, 'Masked area padding'), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) From cf14a6a7aaf8ccb40552990785d5c9e400d93610 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 16:11:18 +0200 Subject: [PATCH 0528/1174] Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right --- .../ScuNET/scripts/scunet_model.py | 48 +++------- .../SwinIR/scripts/swinir_model.py | 62 ++----------- modules/upscaler_utils.py | 89 ++++++++++++++----- 3 files changed, 87 insertions(+), 112 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index f799cb76dec..fe5e5a19265 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,13 +1,9 @@ import sys import PIL.Image -import numpy as np -import torch import modules.upscaler -from modules import devices, modelloader, script_callbacks, errors -from modules.shared import opts -from modules.upscaler_utils import tiled_upscale_2 +from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,46 +36,23 @@ def __init__(self, dirname): self.scalers = scalers def do_upscale(self, img: PIL.Image.Image, selected_file): - devices.torch_gc() - try: model = self.load_model(selected_file) except Exception as e: print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img - device = devices.get_device_for('scunet') - tile = opts.SCUNET_tile - h, w = img.height, img.width - np_img = np.array(img) - np_img = np_img[:, :, ::-1] # RGB to BGR - np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW - torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore - - if tile > h or tile > w: - _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) - _img[:, :, :h, :w] = torch_img # pad image - torch_img = _img - - with torch.no_grad(): - torch_output = tiled_upscale_2( - torch_img, - model, - tile_size=opts.SCUNET_tile, - tile_overlap=opts.SCUNET_tile_overlap, - scale=1, - device=devices.get_device_for('scunet'), - desc="ScuNET tiles", - ).squeeze(0) - torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any - np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() - del torch_img, torch_output + img = upscaler_utils.upscale_2( + img, + model, + tile_size=shared.opts.SCUNET_tile, + tile_overlap=shared.opts.SCUNET_tile_overlap, + scale=1, # ScuNET is a denoising model, not an upscaler + desc='ScuNET', + ) devices.torch_gc() - - output = np_output.transpose((1, 2, 0)) # CHW to HWC - output = output[:, :, ::-1] # BGR to RGB - return PIL.Image.fromarray((output * 255).astype(np.uint8)) + return img def load_model(self, path: str): device = devices.get_device_for('scunet') @@ -93,7 +66,6 @@ def load_model(self, path: str): def on_ui_settings(): import gradio as gr - from modules import shared shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, 
section=('upscaling', "Upscaling")).info("0 = no tiling")) shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 8a555c794f7..bc427feac24 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,14 +1,10 @@ import logging import sys -import numpy as np -import torch from PIL import Image -from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts +from modules import devices, modelloader, script_callbacks, shared, upscaler_utils from modules.upscaler import Upscaler, UpscalerData -from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -36,9 +32,7 @@ def __init__(self, dirname): self.scalers = scalers def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: - current_config = (model_file, opts.SWIN_tile) - - device = self._get_device() + current_config = (model_file, shared.opts.SWIN_tile) if self._cached_model_config == current_config: model = self._cached_model @@ -51,12 +45,13 @@ def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: self._cached_model = model self._cached_model_config = current_config - img = upscale( + img = upscaler_utils.upscale_2( img, model, - tile=opts.SWIN_tile, - tile_overlap=opts.SWIN_tile_overlap, - device=device, + tile_size=shared.opts.SWIN_tile, + tile_overlap=shared.opts.SWIN_tile_overlap, + scale=4, # TODO: This was hard-coded before too... 
+ desc="SwinIR", ) devices.torch_gc() return img @@ -77,7 +72,7 @@ def load_model(self, path, scale=4): dtype=devices.dtype, expected_architecture="SwinIR", ) - if getattr(opts, 'SWIN_torch_compile', False): + if getattr(shared.opts, 'SWIN_torch_compile', False): try: model_descriptor.model.compile() except Exception: @@ -88,47 +83,6 @@ def _get_device(self): return devices.get_device_for('swinir') -def upscale( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size=8, - scale=4, - device, -): - - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device, dtype=devices.dtype) - with torch.no_grad(), devices.autocast(): - _, _, h_old, w_old = img.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old - img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] - img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = tiled_upscale_2( - img, - model, - tile_size=tile, - tile_overlap=tile_overlap, - scale=scale, - device=device, - desc="SwinIR tiles", - ) - output = output[..., : h_old * scale, : w_old * scale] - output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - if output.ndim == 3: - output = np.transpose( - output[[2, 1, 0], :, :], (1, 2, 0) - ) # CHW-RGB to HCW-BGR - output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return Image.fromarray(output, "RGB") - - def on_ui_settings(): import gradio as gr diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 9379f512b12..e4c63f09758 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -11,23 +11,40 @@ logger = logging.getLogger(__name__) -def upscale_without_tiling(model, img: Image.Image): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - +def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor: + img = np.array(img.convert("RGB")) + img = img[:, :, ::-1] # flip RGB to BGR + img = np.transpose(img, (2, 0, 1)) # HWC to CHW + img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1] + return torch.from_numpy(img) + + +def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + # If we're given a tensor with a batch dimension, squeeze it out + # (but only if it's a batch of size 1). + if tensor.shape[0] != 1: + raise ValueError(f"{tensor.shape} does not describe a BCHW tensor") + tensor = tensor.squeeze(0) + assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor" + # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? + arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp + arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale + arr = arr.astype(np.uint8) + arr = arr[:, :, ::-1] # flip BGR to RGB + return Image.fromarray(arr, "RGB") + + +def upscale_pil_patch(model, img: Image.Image) -> Image.Image: + """ + Upscale a given PIL image using the given model. + """ param = torch_utils.get_param(model) - img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): - output = model(img) - - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. 
* np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') + tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension + tensor = tensor.to(device=param.device, dtype=param.dtype) + return torch_bgr_to_pil_image(model(tensor)) def upscale_with_model( @@ -40,7 +57,7 @@ def upscale_with_model( ) -> Image.Image: if tile_size <= 0: logger.debug("Upscaling %s without tiling", img) - output = upscale_without_tiling(model, img) + output = upscale_pil_patch(model, img) logger.debug("=> %s", output) return output @@ -52,7 +69,7 @@ def upscale_with_model( newrow = [] for x, w, tile in row: logger.debug("Tile (%d, %d) %s...", x, y, tile) - output = upscale_without_tiling(model, tile) + output = upscale_pil_patch(model, tile) scale_factor = output.width // tile.width logger.debug("=> %s (scale factor %s)", output, scale_factor) newrow.append([x * scale_factor, w * scale_factor, output]) @@ -71,19 +88,22 @@ def upscale_with_model( def tiled_upscale_2( - img, + img: torch.Tensor, model, *, tile_size: int, tile_overlap: int, scale: int, - device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. + + # Grab the device the model is on, and use it. + device = torch_utils.get_param(model).device + b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -100,7 +120,8 @@ def tiled_upscale_2( h * scale, w * scale, device=device, - ).type_as(img) + dtype=img.dtype, + ) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: @@ -112,11 +133,13 @@ def tiled_upscale_2( if shared.state.interrupted or shared.state.skipped: break + # Only move this patch to the device if it's not already there. in_patch = img[ ..., h_idx : h_idx + tile_size, w_idx : w_idx + tile_size, - ] + ].to(device=device) + out_patch = model(in_patch) result[ @@ -138,3 +161,29 @@ def tiled_upscale_2( output = result.div_(weights) return output + + +def upscale_2( + img: Image.Image, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + desc: str, +): + """ + Convenience wrapper around `tiled_upscale_2` that handles PIL images. + """ + tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + + with torch.no_grad(): + output = tiled_upscale_2( + tensor, + model, + tile_size=tile_size, + tile_overlap=tile_overlap, + scale=scale, + desc=desc, + ) + return torch_bgr_to_pil_image(output) From 2cacbc124c49f45da5b66b79d9b0a3ab943472eb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 19:52:32 +0200 Subject: [PATCH 0529/1174] load_spandrel_model: make `half` `prefer_half` As discussed with the Spandrel folks, it's good to heed Spandrel's "supports half precision" flag to avoid e.g. black blotches and what-not. 
--- modules/modelloader.py | 20 ++++++++++++++------ modules/realesrgan_model.py | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index a7194137571..e100bb2469c 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -139,23 +139,31 @@ def load_upscalers(): def load_spandrel_model( - path: str, + path: str | os.PathLike, *, device: str | torch.device | None, - half: bool = False, + prefer_half: bool = False, dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel - model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path)) if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) - if half: - model_descriptor.model.half() + half = False + if prefer_half: + if model_descriptor.supports_half: + model_descriptor.model.half() + half = True + else: + logger.info("Model %s does not support half precision, ignoring --half", path) if dtype: model_descriptor.model.to(dtype=dtype) model_descriptor.model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + logger.debug( + "Loaded %s from %s (device=%s, half=%s, dtype=%s)", + model_descriptor, path, device, half, dtype, + ) return model_descriptor diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 4d35b695c3b..ff9d8ac0d69 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -39,7 +39,7 @@ def do_upscale(self, img, path): model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, - half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( From 62bd7624d2adb123e84d4625804e9ad94d1db018 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 2 Jan 2024 10:40:26 +0200 Subject: [PATCH 0530/1174] Remove licenses for code that's no longer copy-pasted; adjust README --- README.md | 11 +- html/licenses.html | 310 +-------------------------------------------- 2 files changed, 7 insertions(+), 314 deletions(-) diff --git a/README.md b/README.md index 9f9f33b1295..72908fa7764 100644 --- a/README.md +++ b/README.md @@ -151,11 +151,12 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers - k-diffusion - https://github.com/crowsonkb/k-diffusion.git -- GFPGAN - https://github.com/TencentARC/GFPGAN.git -- CodeFormer - https://github.com/sczhou/CodeFormer -- ESRGAN - https://github.com/xinntao/ESRGAN -- SwinIR - https://github.com/JingyunLiang/SwinIR -- Swin2SR - https://github.com/mv-lab/swin2sr +- Spandrel - https://github.com/chaiNNer-org/spandrel implementing + - GFPGAN - https://github.com/TencentARC/GFPGAN.git + - CodeFormer - https://github.com/sczhou/CodeFormer + - ESRGAN - https://github.com/xinntao/ESRGAN + - SwinIR - https://github.com/JingyunLiang/SwinIR + - Swin2SR - https://github.com/mv-lab/swin2sr - LDSR - https://github.com/Hafiidz/latent-diffusion - MiDaS - 
https://github.com/isl-org/MiDaS - Ideas for optimizations - https://github.com/basujindal/stable-diffusion diff --git a/html/licenses.html b/html/licenses.html index ef6f2c0a42b..9f5d1e9dc5c 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -4,107 +4,6 @@ #licenses pre { margin: 1em 0 2em 0;} -

-CodeFormer
-Parts of CodeFormer code had to be copied to be compatible with GFPGAN.
    -S-Lab License 1.0
    -
    -Copyright 2022 S-Lab
    -
    -Redistribution and use for non-commercial purpose in source and
    -binary forms, with or without modification, are permitted provided
    -that the following conditions are met:
    -
    -1. Redistributions of source code must retain the above copyright
    -   notice, this list of conditions and the following disclaimer.
    -
    -2. Redistributions in binary form must reproduce the above copyright
    -   notice, this list of conditions and the following disclaimer in
    -   the documentation and/or other materials provided with the
    -   distribution.
    -
    -3. Neither the name of the copyright holder nor the names of its
    -   contributors may be used to endorse or promote products derived
    -   from this software without specific prior written permission.
    -
    -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    -
    -In the event that redistribution and/or use for commercial purpose in
    -source or binary forms, with or without modification is required,
    -please contact the contributor(s) of the work.
    -
-ESRGAN
-Code for architecture and reading models copied.
    -MIT License
    -
    -Copyright (c) 2021 victorca25
    -
    -Permission is hereby granted, free of charge, to any person obtaining a copy
    -of this software and associated documentation files (the "Software"), to deal
    -in the Software without restriction, including without limitation the rights
    -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    -copies of the Software, and to permit persons to whom the Software is
    -furnished to do so, subject to the following conditions:
    -
    -The above copyright notice and this permission notice shall be included in all
    -copies or substantial portions of the Software.
    -
    -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    -SOFTWARE.
    -
-Real-ESRGAN
-Some code is copied to support ESRGAN models.
    -BSD 3-Clause License
    -
    -Copyright (c) 2021, Xintao Wang
    -All rights reserved.
    -
    -Redistribution and use in source and binary forms, with or without
    -modification, are permitted provided that the following conditions are met:
    -
    -1. Redistributions of source code must retain the above copyright notice, this
    -   list of conditions and the following disclaimer.
    -
    -2. Redistributions in binary form must reproduce the above copyright notice,
    -   this list of conditions and the following disclaimer in the documentation
    -   and/or other materials provided with the distribution.
    -
    -3. Neither the name of the copyright holder nor the names of its
    -   contributors may be used to endorse or promote products derived from
    -   this software without specific prior written permission.
    -
    -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
    -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
    -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
    -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
    -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    -
    -

    InvokeAI

    Some code for compatibility with OSX is taken from lstein's repository.
    @@ -183,213 +82,6 @@ 

-SwinIR
-Code added by contributors, most likely copied from this repository.
    -                                 Apache License
    -                           Version 2.0, January 2004
    -                        http://www.apache.org/licenses/
    -
    -   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
    -
    -   1. Definitions.
    -
    -      "License" shall mean the terms and conditions for use, reproduction,
    -      and distribution as defined by Sections 1 through 9 of this document.
    -
    -      "Licensor" shall mean the copyright owner or entity authorized by
    -      the copyright owner that is granting the License.
    -
    -      "Legal Entity" shall mean the union of the acting entity and all
    -      other entities that control, are controlled by, or are under common
    -      control with that entity. For the purposes of this definition,
    -      "control" means (i) the power, direct or indirect, to cause the
    -      direction or management of such entity, whether by contract or
    -      otherwise, or (ii) ownership of fifty percent (50%) or more of the
    -      outstanding shares, or (iii) beneficial ownership of such entity.
    -
    -      "You" (or "Your") shall mean an individual or Legal Entity
    -      exercising permissions granted by this License.
    -
    -      "Source" form shall mean the preferred form for making modifications,
    -      including but not limited to software source code, documentation
    -      source, and configuration files.
    -
    -      "Object" form shall mean any form resulting from mechanical
    -      transformation or translation of a Source form, including but
    -      not limited to compiled object code, generated documentation,
    -      and conversions to other media types.
    -
    -      "Work" shall mean the work of authorship, whether in Source or
    -      Object form, made available under the License, as indicated by a
    -      copyright notice that is included in or attached to the work
    -      (an example is provided in the Appendix below).
    -
    -      "Derivative Works" shall mean any work, whether in Source or Object
    -      form, that is based on (or derived from) the Work and for which the
    -      editorial revisions, annotations, elaborations, or other modifications
    -      represent, as a whole, an original work of authorship. For the purposes
    -      of this License, Derivative Works shall not include works that remain
    -      separable from, or merely link (or bind by name) to the interfaces of,
    -      the Work and Derivative Works thereof.
    -
    -      "Contribution" shall mean any work of authorship, including
    -      the original version of the Work and any modifications or additions
    -      to that Work or Derivative Works thereof, that is intentionally
    -      submitted to Licensor for inclusion in the Work by the copyright owner
    -      or by an individual or Legal Entity authorized to submit on behalf of
    -      the copyright owner. For the purposes of this definition, "submitted"
    -      means any form of electronic, verbal, or written communication sent
    -      to the Licensor or its representatives, including but not limited to
    -      communication on electronic mailing lists, source code control systems,
    -      and issue tracking systems that are managed by, or on behalf of, the
    -      Licensor for the purpose of discussing and improving the Work, but
    -      excluding communication that is conspicuously marked or otherwise
    -      designated in writing by the copyright owner as "Not a Contribution."
    -
    -      "Contributor" shall mean Licensor and any individual or Legal Entity
    -      on behalf of whom a Contribution has been received by Licensor and
    -      subsequently incorporated within the Work.
    -
    -   2. Grant of Copyright License. Subject to the terms and conditions of
    -      this License, each Contributor hereby grants to You a perpetual,
    -      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
    -      copyright license to reproduce, prepare Derivative Works of,
    -      publicly display, publicly perform, sublicense, and distribute the
    -      Work and such Derivative Works in Source or Object form.
    -
    -   3. Grant of Patent License. Subject to the terms and conditions of
    -      this License, each Contributor hereby grants to You a perpetual,
    -      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
    -      (except as stated in this section) patent license to make, have made,
    -      use, offer to sell, sell, import, and otherwise transfer the Work,
    -      where such license applies only to those patent claims licensable
    -      by such Contributor that are necessarily infringed by their
    -      Contribution(s) alone or by combination of their Contribution(s)
    -      with the Work to which such Contribution(s) was submitted. If You
    -      institute patent litigation against any entity (including a
    -      cross-claim or counterclaim in a lawsuit) alleging that the Work
    -      or a Contribution incorporated within the Work constitutes direct
    -      or contributory patent infringement, then any patent licenses
    -      granted to You under this License for that Work shall terminate
    -      as of the date such litigation is filed.
    -
    -   4. Redistribution. You may reproduce and distribute copies of the
    -      Work or Derivative Works thereof in any medium, with or without
    -      modifications, and in Source or Object form, provided that You
    -      meet the following conditions:
    -
    -      (a) You must give any other recipients of the Work or
    -          Derivative Works a copy of this License; and
    -
    -      (b) You must cause any modified files to carry prominent notices
    -          stating that You changed the files; and
    -
    -      (c) You must retain, in the Source form of any Derivative Works
    -          that You distribute, all copyright, patent, trademark, and
    -          attribution notices from the Source form of the Work,
    -          excluding those notices that do not pertain to any part of
    -          the Derivative Works; and
    -
    -      (d) If the Work includes a "NOTICE" text file as part of its
    -          distribution, then any Derivative Works that You distribute must
    -          include a readable copy of the attribution notices contained
    -          within such NOTICE file, excluding those notices that do not
    -          pertain to any part of the Derivative Works, in at least one
    -          of the following places: within a NOTICE text file distributed
    -          as part of the Derivative Works; within the Source form or
    -          documentation, if provided along with the Derivative Works; or,
    -          within a display generated by the Derivative Works, if and
    -          wherever such third-party notices normally appear. The contents
    -          of the NOTICE file are for informational purposes only and
    -          do not modify the License. You may add Your own attribution
    -          notices within Derivative Works that You distribute, alongside
    -          or as an addendum to the NOTICE text from the Work, provided
    -          that such additional attribution notices cannot be construed
    -          as modifying the License.
    -
    -      You may add Your own copyright statement to Your modifications and
    -      may provide additional or different license terms and conditions
    -      for use, reproduction, or distribution of Your modifications, or
    -      for any such Derivative Works as a whole, provided Your use,
    -      reproduction, and distribution of the Work otherwise complies with
    -      the conditions stated in this License.
    -
    -   5. Submission of Contributions. Unless You explicitly state otherwise,
    -      any Contribution intentionally submitted for inclusion in the Work
    -      by You to the Licensor shall be under the terms and conditions of
    -      this License, without any additional terms or conditions.
    -      Notwithstanding the above, nothing herein shall supersede or modify
    -      the terms of any separate license agreement you may have executed
    -      with Licensor regarding such Contributions.
    -
    -   6. Trademarks. This License does not grant permission to use the trade
    -      names, trademarks, service marks, or product names of the Licensor,
    -      except as required for reasonable and customary use in describing the
    -      origin of the Work and reproducing the content of the NOTICE file.
    -
    -   7. Disclaimer of Warranty. Unless required by applicable law or
    -      agreed to in writing, Licensor provides the Work (and each
    -      Contributor provides its Contributions) on an "AS IS" BASIS,
    -      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
    -      implied, including, without limitation, any warranties or conditions
    -      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
    -      PARTICULAR PURPOSE. You are solely responsible for determining the
    -      appropriateness of using or redistributing the Work and assume any
    -      risks associated with Your exercise of permissions under this License.
    -
    -   8. Limitation of Liability. In no event and under no legal theory,
    -      whether in tort (including negligence), contract, or otherwise,
    -      unless required by applicable law (such as deliberate and grossly
    -      negligent acts) or agreed to in writing, shall any Contributor be
    -      liable to You for damages, including any direct, indirect, special,
    -      incidental, or consequential damages of any character arising as a
    -      result of this License or out of the use or inability to use the
    -      Work (including but not limited to damages for loss of goodwill,
    -      work stoppage, computer failure or malfunction, or any and all
    -      other commercial damages or losses), even if such Contributor
    -      has been advised of the possibility of such damages.
    -
    -   9. Accepting Warranty or Additional Liability. While redistributing
    -      the Work or Derivative Works thereof, You may choose to offer,
    -      and charge a fee for, acceptance of support, warranty, indemnity,
    -      or other liability obligations and/or rights consistent with this
    -      License. However, in accepting such obligations, You may act only
    -      on Your own behalf and on Your sole responsibility, not on behalf
    -      of any other Contributor, and only if You agree to indemnify,
    -      defend, and hold each Contributor harmless for any liability
    -      incurred by, or claims asserted against, such Contributor by reason
    -      of your accepting any such warranty or additional liability.
    -
    -   END OF TERMS AND CONDITIONS
    -
    -   APPENDIX: How to apply the Apache License to your work.
    -
    -      To apply the Apache License to your work, attach the following
    -      boilerplate notice, with the fields enclosed by brackets "[]"
    -      replaced with your own identifying information. (Don't include
    -      the brackets!)  The text should be enclosed in the appropriate
    -      comment syntax for the file format. We also recommend that a
    -      file or class name and description of purpose be included on the
    -      same "printed page" as the copyright notice for easier
    -      identification within third-party archives.
    -
    -   Copyright [2021] [SwinIR Authors]
    -
    -   Licensed under the Apache License, Version 2.0 (the "License");
    -   you may not use this file except in compliance with the License.
    -   You may obtain a copy of the License at
    -
    -       http://www.apache.org/licenses/LICENSE-2.0
    -
    -   Unless required by applicable law or agreed to in writing, software
    -   distributed under the License is distributed on an "AS IS" BASIS,
    -   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    -   See the License for the specific language governing permissions and
    -   limitations under the License.
    -
    -

    Memory Efficient Attention

    The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.
    @@ -687,4 +379,4 @@ 

TAESD
-
\ No newline at end of file
+

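Patch 0528 above concentrates all PIL-to-tensor conversion in the new pil_image_to_torch_bgr and torch_bgr_to_pil_image helpers, which every tiled upscaler now goes through. A quick round-trip sanity check is sketched below; it assumes it runs inside the patched webui tree (so modules.upscaler_utils is importable) with Pillow, NumPy and PyTorch installed, and is not itself part of any patch in this series.

    import numpy as np
    from PIL import Image

    from modules.upscaler_utils import pil_image_to_torch_bgr, torch_bgr_to_pil_image

    img = Image.new("RGB", (8, 8), (255, 128, 0))

    # CHW layout, BGR channel order, values scaled to [0, 1].
    tensor = pil_image_to_torch_bgr(img)
    assert tuple(tensor.shape) == (3, 8, 8)

    # torch_bgr_to_pil_image also accepts a BCHW batch of size 1.
    roundtrip = torch_bgr_to_pil_image(tensor.unsqueeze(0))

    # Allow one count of quantization slack; the next patch in the series
    # ("torch_bgr_to_pil_image: round, don't truncate") makes the trip exact.
    diff = np.abs(np.asarray(img, np.int16) - np.asarray(roundtrip, np.int16))
    assert diff.max() <= 1
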
    From 7ad6899bf987a8ee615efbcfc99562457f89cd8b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 2 Jan 2024 17:14:05 +0200 Subject: [PATCH 0531/1174] torch_bgr_to_pil_image: round, don't truncate This matches what `realesrgan` does. --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f09758..4f1417cf06a 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -30,7 +30,7 @@ def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale - arr = arr.astype(np.uint8) + arr = arr.round().astype(np.uint8) arr = arr[:, :, ::-1] # flip BGR to RGB return Image.fromarray(arr, "RGB") From fccd0b00c2ca17360b7b956cd2e9bd1fb42c017d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 3 Jan 2024 18:55:43 +0900 Subject: [PATCH 0532/1174] reduce unnecessary re-indexing extra networks dir --- modules/ui_extra_networks.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index beea1316083..e1c679eca35 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -417,21 +417,21 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") + def create_html(): + ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] + def pages_html(): if not ui.pages_contents: - return refresh() - + create_html() return ui.pages_contents def refresh(): for pg in ui.stored_extra_pages: pg.refresh() - - ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] - + create_html() return ui.pages_contents - interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages]) + interface.load(fn=pages_html, inputs=[], outputs=ui.pages) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui From bfc48fbc244130770991fab284f6fedcef2054e7 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 4 Jan 2024 03:46:05 +0900 Subject: [PATCH 0533/1174] paste infotext cast int as float --- modules/infotext_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index e582ee479c6..a21329e6238 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -477,6 +477,8 @@ def paste_func(prompt): if valtype == bool and v == "False": val = False + elif valtype == int: + val = float(v) else: val = valtype(v) From dfdc51246c678b585e1bdfdb7d2f202b0ca0e362 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:13 +0200 Subject: [PATCH 0534/1174] SwinIR: use prefer_half --- extensions-builtin/SwinIR/scripts/swinir_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index bc427feac24..6a8e21b0250 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,6 +1,7 @@ import logging import sys +import torch from PIL import Image from modules import devices, modelloader, script_callbacks, shared, upscaler_utils @@ -69,7 +70,7 @@ 
def load_model(self, path, scale=4): model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), - dtype=devices.dtype, + prefer_half=(devices.dtype == torch.float16), expected_architecture="SwinIR", ) if getattr(shared.opts, 'SWIN_torch_compile', False): From 3d31d5c27beb433fa37b30f135ec06a278a87630 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:49 +0200 Subject: [PATCH 0535/1174] SwinIR: pass model.scale --- extensions-builtin/SwinIR/scripts/swinir_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 6a8e21b0250..16bf9b792fc 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -51,7 +51,7 @@ def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: model, tile_size=shared.opts.SWIN_tile, tile_overlap=shared.opts.SWIN_tile_overlap, - scale=4, # TODO: This was hard-coded before too... + scale=model.scale, desc="SwinIR", ) devices.torch_gc() From 62470ee23443cb2ad3943a152ccae26a689c86e1 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:39:12 +0200 Subject: [PATCH 0536/1174] upscale_2: cast image to model's dtype --- modules/upscaler_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f09758..5db74877617 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -94,6 +94,7 @@ def tiled_upscale_2( tile_size: int, tile_overlap: int, scale: int, + device: torch.device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by @@ -101,9 +102,6 @@ def tiled_upscale_2( # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. - # Grab the device the model is on, and use it. - device = torch_utils.get_param(model).device - b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -175,7 +173,8 @@ def upscale_2( """ Convenience wrapper around `tiled_upscale_2` that handles PIL images. 
""" - tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + param = torch_utils.get_param(model) + tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension with torch.no_grad(): output = tiled_upscale_2( @@ -185,5 +184,6 @@ def upscale_2( tile_overlap=tile_overlap, scale=scale, desc=desc, + device=param.device, ) return torch_bgr_to_pil_image(output) From 50158a1fc9b4dd47a7bef70d34fbb0b30d5e8b47 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 4 Jan 2024 06:21:53 +0900 Subject: [PATCH 0537/1174] handle config.json failed to load --- modules/launch_utils.py | 4 +++- modules/options.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c2cbd8ce73e..e2ad412a6bb 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -249,7 +249,9 @@ def list_extensions(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) except Exception: - errors.report("Could not load settings", exc_info=True) + errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True) + os.replace(settings_file, os.path.join(script_path, "tmp", "config.json")) + settings = {} disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/options.py b/modules/options.py index 09ff9403ddc..503b40e98db 100644 --- a/modules/options.py +++ b/modules/options.py @@ -1,3 +1,4 @@ +import os import json import sys from dataclasses import dataclass @@ -6,6 +7,7 @@ from modules import errors from modules.shared_cmd_options import cmd_opts +from modules.paths_internal import script_path class OptionInfo: @@ -193,9 +195,13 @@ def same_type(self, x, y): return type_x == type_y def load(self, filename): - with open(filename, "r", encoding="utf8") as file: - self.data = json.load(file) - + try: + with open(filename, "r", encoding="utf8") as file: + self.data = json.load(file) + except Exception: + errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True) + os.replace(filename, os.path.join(script_path, "tmp", "config.json")) + self.data = {} # 1.6.0 VAE defaults if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None: self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default') From d9034b48a526f0a0c3e8f0dbf7c171bf4f0597fd Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jan 2024 00:16:58 +0200 Subject: [PATCH 0538/1174] Avoid unnecessary `isfile`/`exists` calls --- modules/cache.py | 17 ++++++++--------- modules/extensions.py | 11 ++++++----- modules/extra_networks.py | 7 ++++--- modules/infotext_utils.py | 4 +++- modules/launch_utils.py | 7 ++++--- modules/postprocessing.py | 7 ++++--- modules/shared_init.py | 4 +++- modules/ui_gradio_extensions.py | 8 +++----- modules/ui_loadsave.py | 5 +++-- modules/util.py | 6 +++--- 10 files changed, 41 insertions(+), 35 deletions(-) diff --git a/modules/cache.py b/modules/cache.py index 2d37e7b99d5..a9822a0eb21 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -62,16 +62,15 @@ def cache(subsection): if 
cache_data is None: with cache_lock: if cache_data is None: - if not os.path.isfile(cache_filename): + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except FileNotFoundError: + cache_data = {} + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') cache_data = {} - else: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') - cache_data = {} s = cache_data.get(subsection, {}) cache_data[subsection] = s diff --git a/modules/extensions.py b/modules/extensions.py index 1899cd52975..99e7ee60f64 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -32,11 +32,12 @@ def __init__(self, path, canonical_name): self.config = configparser.ConfigParser() filepath = os.path.join(path, self.filename) - if os.path.isfile(filepath): - try: - self.config.read(filepath) - except Exception: - errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) + # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is), + # so no need to check whether the file exists beforehand. + try: + self.config.read(filepath) + except Exception: + errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) self.canonical_name = canonical_name.lower().strip() diff --git a/modules/extra_networks.py b/modules/extra_networks.py index b9533677887..cd030fa3184 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -215,9 +215,10 @@ def get_user_metadata(filename): metadata = {} try: - if os.path.isfile(metadata_filename): - with open(metadata_filename, "r", encoding="utf8") as file: - metadata = json.load(file) + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except FileNotFoundError: + pass except Exception as e: errors.display(e, f"reading extra network user metadata from {metadata_filename}") diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index e582ee479c6..6978a0bf0f0 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -453,9 +453,11 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(data_path, "params.txt") - if os.path.exists(filename): + try: with open(filename, "r", encoding="utf8") as file: prompt = file.read() + except OSError: + pass params = parse_generation_parameters(prompt) script_callbacks.infotext_pasted_callback(prompt, params) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c2cbd8ce73e..febd8c24a56 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -245,9 +245,10 @@ def list_extensions(settings_file): settings = {} try: - if os.path.isfile(settings_file): - with open(settings_file, "r", encoding="utf8") as file: - settings = json.load(file) + with open(settings_file, "r", encoding="utf8") as file: + settings = json.load(file) + except FileNotFoundError: + pass except Exception: 
errors.report("Could not load settings", exc_info=True) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 7850328f63f..7449b0dc51c 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -97,11 +97,12 @@ def get_images(extras_mode, image, image_folder, input_dir): if pp.caption: caption_filename = os.path.splitext(fullfn)[0] + ".txt" - if os.path.isfile(caption_filename): + existing_caption = "" + try: with open(caption_filename, encoding="utf8") as file: existing_caption = file.read().strip() - else: - existing_caption = "" + except FileNotFoundError: + pass action = shared.opts.postprocessing_existing_caption_action if action == 'Prepend' and existing_caption: diff --git a/modules/shared_init.py b/modules/shared_init.py index d3fb687e0cd..586be3423d0 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -18,8 +18,10 @@ def initialize(): shared.options_templates = shared_options.options_templates shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts) shared.restricted_opts = shared_options.restricted_opts - if os.path.exists(shared.config_filename): + try: shared.opts.load(shared.config_filename) + except FileNotFoundError: + pass from modules import devices devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index a86c368ef70..f5278d22f02 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -35,13 +35,11 @@ def stylesheet(fn): return f'' for cssfile in scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - head += stylesheet(cssfile) - if os.path.exists(os.path.join(data_path, "user.css")): - head += stylesheet(os.path.join(data_path, "user.css")) + user_css = os.path.join(data_path, "user.css") + if os.path.exists(user_css): + head += stylesheet(user_css) return head diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 693ff75c5d8..2555cdb6c60 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -26,8 +26,9 @@ def __init__(self, filename): self.ui_defaults_review = None try: - if os.path.exists(self.filename): - self.ui_settings = self.read_from_file() + self.ui_settings = self.read_from_file() + except FileNotFoundError: + pass except Exception as e: self.error_loading = True errors.display(e, "loading settings") diff --git a/modules/util.py b/modules/util.py index 4861bcb0861..d503f26729e 100644 --- a/modules/util.py +++ b/modules/util.py @@ -21,11 +21,11 @@ def html_path(filename): def html(filename): path = html_path(filename) - if os.path.exists(path): + try: with open(path, encoding="utf8") as file: return file.read() - - return "" + except OSError: + return "" def walk_files(path, allowed_extensions=None): From 420f56c2e85ebbd3f530cf2c7b22022fda13ae13 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 4 Jan 2024 02:28:05 +0300 Subject: [PATCH 0539/1174] mass file lister as an attempt to tackle #14507 --- modules/extra_networks.py | 5 +-- modules/ui_extra_networks.py | 18 ++++++---- modules/util.py | 70 ++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 8 deletions(-) diff --git a/modules/extra_networks.py b/modules/extra_networks.py index b9533677887..04249dffd27 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -206,7 +206,7 @@ def parse_prompts(prompts): return res, 
extra_data -def get_user_metadata(filename): +def get_user_metadata(filename, lister=None): if filename is None: return {} @@ -215,7 +215,8 @@ def get_user_metadata(filename): metadata = {} try: - if os.path.isfile(metadata_filename): + exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename) + if exists: with open(metadata_filename, "r", encoding="utf8") as file: metadata = json.load(file) except Exception as e: diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e1c679eca35..c06c8664950 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -3,7 +3,7 @@ import urllib.parse from pathlib import Path -from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks +from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util from modules.images import read_info_from_image, save_image_with_geninfo import gradio as gr import json @@ -107,6 +107,7 @@ def __init__(self, title): self.allow_negative_prompt = False self.metadata = {} self.items = {} + self.lister = util.MassFileLister() def refresh(self): pass @@ -123,7 +124,7 @@ def read_user_metadata(self, item): def link_preview(self, filename): quoted_filename = urllib.parse.quote(filename.replace('\\', '/')) - mtime = os.path.getmtime(filename) + mtime, _ = self.lister.mctime(filename) return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}" def search_terms_from_path(self, filename, possible_directories=None): @@ -137,6 +138,8 @@ def search_terms_from_path(self, filename, possible_directories=None): return "" def create_html(self, tabname): + self.lister.reset() + items_html = '' self.metadata = {} @@ -282,10 +285,10 @@ def get_sort_keys(self, path): List of default keys used for sorting in the UI. """ pth = Path(path) - stat = pth.stat() + mtime, ctime = self.lister.mctime(path) return { - "date_created": int(stat.st_ctime or 0), - "date_modified": int(stat.st_mtime or 0), + "date_created": int(mtime), + "date_modified": int(ctime), "name": pth.name.lower(), "path": str(pth.parent).lower(), } @@ -298,7 +301,7 @@ def find_preview(self, path): potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], []) for file in potential_files: - if os.path.isfile(file): + if self.lister.exists(file): return self.link_preview(file) return None @@ -308,6 +311,9 @@ def find_description(self, path): Find and read a description file for a given path (without extension). 
""" for file in [f"{path}.txt", f"{path}.description.txt"]: + if not self.lister.exists(file): + continue + try: with open(file, "r", encoding="utf-8", errors="replace") as f: return f.read() diff --git a/modules/util.py b/modules/util.py index 4861bcb0861..c2a27590dca 100644 --- a/modules/util.py +++ b/modules/util.py @@ -66,3 +66,73 @@ def truncate_path(target_path, base_path=cwd): except ValueError: pass return abs_target + + +class MassFileListerCachedDir: + """A class that caches file metadata for a specific directory.""" + + def __init__(self, dirname): + self.files = None + self.files_cased = None + self.dirname = dirname + + stats = ((x.name, x.stat(follow_symlinks=False)) for x in os.scandir(self.dirname)) + files = [(n, s.st_mtime, s.st_ctime) for n, s in stats] + self.files = {x[0].lower(): x for x in files} + self.files_cased = {x[0]: x for x in files} + + +class MassFileLister: + """A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file.""" + + def __init__(self): + self.cached_dirs = {} + + def find(self, path): + """ + Find the metadata for a file at the given path. + + Returns: + tuple or None: A tuple of (name, mtime, ctime) if the file exists, or None if it does not. + """ + + dirname, filename = os.path.split(path) + + cached_dir = self.cached_dirs.get(dirname) + if cached_dir is None: + cached_dir = MassFileListerCachedDir(dirname) + self.cached_dirs[dirname] = cached_dir + + stats = cached_dir.files_cased.get(filename) + if stats is not None: + return stats + + stats = cached_dir.files.get(filename.lower()) + if stats is None: + return None + + try: + os_stats = os.stat(path, follow_symlinks=False) + return filename, os_stats.st_mtime, os_stats.st_ctime + except Exception: + return None + + def exists(self, path): + """Check if a file exists at the given path.""" + + return self.find(path) is not None + + def mctime(self, path): + """ + Get the modification and creation times for a file at the given path. + + Returns: + tuple: A tuple of (mtime, ctime) if the file exists, or (0, 0) if it does not. + """ + + stats = self.find(path) + return (0, 0) if stats is None else stats[1:3] + + def reset(self): + """Clear the cache of all directories.""" + self.cached_dirs.clear() From 320a217b78047f30e1aa5e735742669a7f4c6bd8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 4 Jan 2024 02:39:02 +0300 Subject: [PATCH 0540/1174] forgot something --- modules/ui_extra_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index c06c8664950..62db36f53cc 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -114,7 +114,7 @@ def refresh(self): def read_user_metadata(self, item): filename = item.get("filename", None) - metadata = extra_networks.get_user_metadata(filename) + metadata = extra_networks.get_user_metadata(filename, lister=self.lister) desc = metadata.get("description", None) if desc is not None: From 15ec54dd969d6dc3fea7790ca5cce5badcfda426 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 4 Jan 2024 19:47:00 +0300 Subject: [PATCH 0541/1174] Have upscale button use the same seed as hires fix. 
--- modules/scripts.py | 21 +++++++++++++++++++++ modules/txt2img.py | 14 +++++++++++++- modules/ui.py | 10 +++++----- modules/ui_common.py | 26 +++++++++++++------------- modules/ui_postprocessing.py | 2 +- 5 files changed, 53 insertions(+), 20 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 017aed5a01e..cf938ebb972 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -91,6 +91,9 @@ class Script: setup_for_ui_only = False """If true, the script setup will only be run in Gradio UI, not in API""" + controls = None + """A list of controls returned by the ui().""" + def title(self): """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" @@ -624,6 +627,7 @@ def create_script_ui_inner(self, script): import modules.api.models as api_models controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) + script.controls = controls if controls is None: return @@ -918,6 +922,23 @@ def setup_scrips(self, p, *, is_ui=True): except Exception: errors.report(f"Error running setup: {script.filename}", exc_info=True) + def set_named_arg(self, args, script_type, arg_elem_id, value): + script = next((x for x in self.scripts if type(x).__name__ == script_type), None) + if script is None: + return + + for i, control in enumerate(script.controls): + if arg_elem_id in control.elem_id: + index = script.args_from + i + + if isinstance(args, list): + args[index] = value + return args + elif isinstance(args, tuple): + return args[:index] + (value,) + args[index+1:] + else: + return None + scripts_txt2img: ScriptRunner = None scripts_img2img: ScriptRunner = None diff --git a/modules/txt2img.py b/modules/txt2img.py index 4a6fe72a65d..41bb9da33dd 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,3 +1,4 @@ +import json from contextlib import closing import modules.scripts @@ -9,12 +10,19 @@ import gradio as gr -def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args): +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args): assert len(gallery) > 0, 'No image to upscale' + assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}' + + geninfo = json.loads(generation_info) + all_seeds = geninfo["all_seeds"] image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] image = infotext_utils.image_from_url_text(image_info) + gallery_index_from_end = len(gallery) - gallery_index + image.seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] + return txt2img(id_task, request, *args, firstpass_image=image) @@ -22,6 +30,10 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, override_settings = create_override_settings_dict(override_settings_texts) if firstpass_image is not None: + seed = getattr(firstpass_image, 'seed', None) + if seed: + args = modules.scripts.scripts_txt2img.set_named_arg(args, 'ScriptSeed', 'seed', seed) + enable_hr = True batch_size = 1 n_iter = 1 diff --git a/modules/ui.py b/modules/ui.py index 7116d71c582..2d2e333b293 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -405,8 +405,8 @@ def create_ui(): txt2img_outputs = [ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ] @@ -424,7 +424,7 @@ def create_ui(): output_panel.button_upscale.click( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
_js="submit_txt2img_upscale", - inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:], + inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:], outputs=txt2img_outputs, show_progress=False, ) @@ -437,8 +437,8 @@ def create_ui(): inputs=[dummy_component], outputs=[ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ], show_progress=False, @@ -766,8 +766,8 @@ def select_img2img_tab(tab): ] + custom_inputs, outputs=[ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ], show_progress=False, @@ -807,8 +807,8 @@ def select_img2img_tab(tab): inputs=[dummy_component], outputs=[ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ], show_progress=False, diff --git a/modules/ui_common.py b/modules/ui_common.py index ff84197c1de..f17259c2956 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -108,8 +108,8 @@ def __init__(self, d=None): @dataclasses.dataclass class OutputPanel: gallery = None + generation_info = None infotext = None - html_info = None html_log = None button_upscale = None @@ -175,17 +175,17 @@ def open_folder(f): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): - res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") - res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[res.infotext, res.html_info, res.html_info], - outputs=[res.html_info, res.html_info], + inputs=[res.generation_info, res.infotext, res.infotext], + outputs=[res.infotext, res.infotext], show_progress=False, ) @@ -193,10 +193,10 @@ def open_folder(f): fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ - res.infotext, + res.generation_info, res.gallery, - res.html_info, - res.html_info, + res.infotext, + res.infotext, ], outputs=[ download_files, @@ -209,10 +209,10 @@ def open_folder(f): fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", inputs=[ - res.infotext, + res.generation_info, res.gallery, - res.html_info, - res.html_info, + res.infotext, + res.infotext, ], outputs=[ download_files, @@ -221,8 +221,8 @@ def open_folder(f): ) else: - res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}') - res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}') + res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") res.html_log = gr.HTML(elem_id=f'html_log_{tabname}') paste_field_names = [] diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 
8f09e6588a3..7a132ac22ed 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -49,7 +49,7 @@ def create_ui(): ], outputs=[ output_panel.gallery, - output_panel.infotext, + output_panel.generation_info, output_panel.html_log, ], show_progress=False, From 9805f35c6f3ef0b0fc4e3648aa3d6eddf0a907af Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jan 2024 19:13:36 +0200 Subject: [PATCH 0542/1174] Ensure GRADIO_ANALYTICS_ENABLED is set early enough --- modules/initialize.py | 2 ++ modules/launch_utils.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/initialize.py b/modules/initialize.py index 4a3cd98cf86..7c1ac99ef07 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -1,5 +1,6 @@ import importlib import logging +import os import sys import warnings from threading import Thread @@ -18,6 +19,7 @@ def imports(): warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False') import gradio # noqa: F401 startup_timer.record("import gradio") diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 7ebdf0b40a3..c2a7ae932ce 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -27,8 +27,7 @@ # Whether to default to printing command output default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1") -if 'GRADIO_ANALYTICS_ENABLED' not in os.environ: - os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' +os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False') def check_python_version(): From 6fa42e919f27d7316a2c7b61674fb3eb17f3a1bb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jan 2024 16:13:08 +0200 Subject: [PATCH 0543/1174] Fix logging configuration again * Only use `tqdm.write()` if `tqdm` is active, defer to stderr * Correct log formatter for TqdmLoggingHandler * If `rich` is installed and `SD_WEBUI_RICH_LOG` is set, use `rich`'s formatter --- modules/logging_config.py | 62 ++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/modules/logging_config.py b/modules/logging_config.py index 79269875608..11eee9a63db 100644 --- a/modules/logging_config.py +++ b/modules/logging_config.py @@ -1,41 +1,57 @@ -import os import logging +import os try: - from tqdm.auto import tqdm + from tqdm import tqdm + class TqdmLoggingHandler(logging.Handler): - def __init__(self, level=logging.INFO): - super().__init__(level) + def __init__(self, fallback_handler: logging.Handler): + super().__init__() + self.fallback_handler = fallback_handler def emit(self, record): try: - msg = self.format(record) - tqdm.write(msg) - self.flush() + # If there are active tqdm progress bars, + # attempt to not interfere with them. + if tqdm._instances: + tqdm.write(self.format(record)) + else: + self.fallback_handler.emit(record) except Exception: - self.handleError(record) + self.fallback_handler.emit(record) - TQDM_IMPORTED = True except ImportError: - # tqdm does not exist before first launch - # I will import once the UI finishes seting up the enviroment and reloads. 
- TQDM_IMPORTED = False + TqdmLoggingHandler = None + def setup_logging(loglevel): if loglevel is None: loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL") - loghandlers = [] + if not loglevel: + return + + if logging.root.handlers: + # Already configured, do not interfere + return + + if os.environ.get("SD_WEBUI_RICH_LOG"): + from rich.logging import RichHandler + handler = RichHandler() + else: + handler = logging.StreamHandler() + + if TqdmLoggingHandler: + handler = TqdmLoggingHandler(handler) + + formatter = logging.Formatter( + '%(asctime)s %(levelname)s [%(name)s] %(message)s', + '%Y-%m-%d %H:%M:%S', + ) - if TQDM_IMPORTED: - loghandlers.append(TqdmLoggingHandler()) + handler.setFormatter(formatter) - if loglevel: - log_level = getattr(logging, loglevel.upper(), None) or logging.INFO - logging.basicConfig( - level=log_level, - format='%(asctime)s %(levelname)s [%(name)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - handlers=loghandlers - ) + log_level = getattr(logging, loglevel.upper(), None) or logging.INFO + logging.root.setLevel(log_level) + logging.root.addHandler(handler) From f8f38c7c28e48f9f79225c969e3e82b1adcfb910 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:31:48 +0800 Subject: [PATCH 0544/1174] Fix dtype casting for OFT module --- extensions-builtin/Lora/network_oft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fa647020f0a..342fcd0dcaf 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,7 +56,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + oft_blocks = self.oft_blocks.to(orig_weight.device) eye = torch.eye(self.block_size, device=self.oft_blocks.device) if self.is_kohya: @@ -66,7 +66,7 @@ def calc_updown(self, orig_weight): block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) - R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = oft_blocks.to(orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -77,6 +77,6 @@ def calc_updown(self, orig_weight): ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) From 18ca987c92f52690daec43a6c67363c341bb6008 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:32:19 +0800 Subject: [PATCH 0545/1174] Add general forward method for all modules. 
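This commit lets one generic `forward` serve every supported layer type by pairing each wrapped module with its `torch.nn.functional` counterpart plus the keyword arguments that op needs. A condensed sketch of the dispatch, mirroring the isinstance chain in the diff (the standalone helper is illustrative):

    import torch.nn as nn
    import torch.nn.functional as F

    def pick_op(module: nn.Module):
        # Map a wrapped layer to (functional op, extra kwargs for that op).
        if isinstance(module, nn.Conv2d):
            return F.conv2d, {'stride': module.stride, 'padding': module.padding}
        if isinstance(module, nn.Linear):
            return F.linear, {}
        if isinstance(module, nn.LayerNorm):
            return F.layer_norm, {'normalized_shape': module.normalized_shape, 'eps': module.eps}
        if isinstance(module, nn.GroupNorm):
            return F.group_norm, {'num_groups': module.num_groups, 'eps': module.eps}
        return None, {}  # unsupported layer: the subclass must provide its own forward

The generic forward then computes the module's weight delta and applies `op(x, weight=updown, bias=ex_bias, **extra_kwargs)` on top of the layer's original output, as the diff below does.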
--- extensions-builtin/Lora/network.py | 34 ++++++++++++++++++++++++++++- extensions-builtin/Lora/networks.py | 12 +++++----- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index a62e5eff9e1..f9b571b5fa1 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,6 +3,10 @@ from collections import namedtuple import enum +import torch +import torch.nn as nn +import torch.nn.functional as F + from modules import sd_models, cache, errors, hashes, shared NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) @@ -115,6 +119,29 @@ def __init__(self, net: Network, weights: NetworkWeights): if hasattr(self.sd_module, 'weight'): self.shape = self.sd_module.weight.shape + self.ops = None + self.extra_kwargs = {} + if isinstance(self.sd_module, nn.Conv2d): + self.ops = F.conv2d + self.extra_kwargs = { + 'stride': self.sd_module.stride, + 'padding': self.sd_module.padding + } + elif isinstance(self.sd_module, nn.Linear): + self.ops = F.linear + elif isinstance(self.sd_module, nn.LayerNorm): + self.ops = F.layer_norm + self.extra_kwargs = { + 'normalized_shape': self.sd_module.normalized_shape, + 'eps': self.sd_module.eps + } + elif isinstance(self.sd_module, nn.GroupNorm): + self.ops = F.group_norm + self.extra_kwargs = { + 'num_groups': self.sd_module.num_groups, + 'eps': self.sd_module.eps + } + self.dim = None self.bias = weights.w.get("bias") self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None @@ -155,5 +182,10 @@ def calc_updown(self, target): raise NotImplementedError() def forward(self, x, y): - raise NotImplementedError() + """A general forward implementation for all modules""" + if self.ops is None: + raise NotImplementedError() + else: + updown, ex_bias = self.calc_updown(self.sd_module.weight) + return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 72ebd624169..32e10b62544 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_current_names = wanted_names -def network_forward(module, input, original_forward): +def network_forward(org_module, input, original_forward): """ Old way of applying Lora by executing operations during layer's forward. Stacking many loras this way results in big performance degradation. 
""" if len(loaded_networks) == 0: - return original_forward(module, input) + return original_forward(org_module, input) input = devices.cond_cast_unet(input) - network_restore_weights_from_backup(module) - network_reset_cached_weight(module) + network_restore_weights_from_backup(org_module) + network_reset_cached_weight(org_module) - y = original_forward(module, input) + y = original_forward(org_module, input) - network_layer_name = getattr(module, 'network_layer_name', None) + network_layer_name = getattr(org_module, 'network_layer_name', None) for lora in loaded_networks: module = lora.modules.get(network_layer_name, None) if module is None: From 44744d6005da5e424267698ee3279caa597dfebc Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:38:05 +0800 Subject: [PATCH 0546/1174] linting --- extensions-builtin/Lora/network.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index f9b571b5fa1..b8fd91941c8 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,7 +3,6 @@ from collections import namedtuple import enum -import torch import torch.nn as nn import torch.nn.functional as F @@ -124,7 +123,7 @@ def __init__(self, net: Network, weights: NetworkWeights): if isinstance(self.sd_module, nn.Conv2d): self.ops = F.conv2d self.extra_kwargs = { - 'stride': self.sd_module.stride, + 'stride': self.sd_module.stride, 'padding': self.sd_module.padding } elif isinstance(self.sd_module, nn.Linear): From 88ba095fd022fbc1acbbbc845e23493a5b2c5d8b Mon Sep 17 00:00:00 2001 From: Keshav Nischal <91429557+keshav-nischal@users.noreply.github.com> Date: Fri, 5 Jan 2024 14:15:58 +0530 Subject: [PATCH 0547/1174] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9f9f33b1295..7f1ea457522 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Stable Diffusion web UI -A browser interface based on Gradio library for Stable Diffusion. +A web interface for Stable Diffusion, implemented using Gradio library. ![](screenshot.png) From 233c66b36eba07b905f9743c2ad807aec33d9ccb Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 5 Jan 2024 12:28:32 +0300 Subject: [PATCH 0548/1174] Make the upscale button update the gallery with the new image rather than replace it. 
--- modules/processing.py | 6 ++- modules/txt2img.py | 85 +++++++++++++++++++++++++++++-------------- 2 files changed, 63 insertions(+), 28 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 84e7b1b4577..dcc807fe38d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -732,7 +732,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), - "Denoising strength": getattr(p, 'denoising_strength', None), + "Denoising strength": p.extra_generation_params.get("Denoising strength"), "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": opts.eta_noise_seed_delta if uses_ensd else None, @@ -1198,6 +1198,8 @@ def calculate_target_resolution(self): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + self.extra_generation_params["Denoising strength"] = self.denoising_strength + if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint': self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) @@ -1516,6 +1518,8 @@ def mask_blur(self, value): self.mask_blur_y = value def init(self, all_prompts, all_seeds, all_subseeds): + self.extra_generation_params["Denoising strength"] = self.denoising_strength + self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) diff --git a/modules/txt2img.py b/modules/txt2img.py index 41bb9da33dd..c4cc12d2f6d 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -7,36 +7,15 @@ from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html +from PIL import Image import gradio as gr -def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args): - assert len(gallery) > 0, 'No image to upscale' - assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}' - - geninfo = json.loads(generation_info) - all_seeds = geninfo["all_seeds"] - - image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] - image = infotext_utils.image_from_url_text(image_info) - - gallery_index_from_end = len(gallery) - gallery_index - image.seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] - - return txt2img(id_task, request, *args, firstpass_image=image) - - -def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, firstpass_image=None): +def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, 
sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False): override_settings = create_override_settings_dict(override_settings_texts) - if firstpass_image is not None: - seed = getattr(firstpass_image, 'seed', None) - if seed: - args = modules.scripts.scripts_txt2img.set_named_arg(args, 'ScriptSeed', 'seed', seed) - + if force_enable_hr: enable_hr = True - batch_size = 1 - n_iter = 1 p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -53,7 +32,7 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str width=width, height=height, enable_hr=enable_hr, - denoising_strength=denoising_strength if enable_hr else None, + denoising_strength=denoising_strength, hr_scale=hr_scale, hr_upscaler=hr_upscaler, hr_second_pass_steps=hr_second_pass_steps, @@ -64,7 +43,6 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, - firstpass_image=firstpass_image, ) p.scripts = modules.scripts.scripts_txt2img @@ -75,8 +53,61 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str if shared.opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + return p + + +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args): + assert len(gallery) > 0, 'No image to upscale' + assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}' + + p = txt2img_create_processing(id_task, request, *args) + p.enable_hr = True + p.batch_size = 1 + p.n_iter = 1 + + geninfo = json.loads(generation_info) + all_seeds = geninfo["all_seeds"] + + image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] + p.firstpass_image = infotext_utils.image_from_url_text(image_info) + + gallery_index_from_end = len(gallery) - gallery_index + seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] + p.script_args = modules.scripts.scripts_txt2img.set_named_arg(p.script_args, 'ScriptSeed', 'seed', seed) + + with closing(p): + processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) + + if processed is None: + processed = processing.process_images(p) + + shared.total_tqdm.clear() + + new_gallery = [] + for i, image in enumerate(gallery): + fake_image = Image.new(mode="RGB", size=(1, 1)) + + if i == gallery_index: + already_saved_as = getattr(processed.images[0], 'already_saved_as', None) + if already_saved_as is not None: + fake_image.already_saved_as = already_saved_as + else: + fake_image = processed.images[0] + else: + fake_image.already_saved_as = image["name"] + + new_gallery.append(fake_image) + + geninfo["infotexts"][gallery_index] = processed.info + + return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments") + + +def txt2img(id_task: str, request: gr.Request, *args): + p = txt2img_create_processing(id_task, request, *args) + with closing(p): - processed = modules.scripts.scripts_txt2img.run(p, *args) + processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) if processed is 
None: processed = processing.process_images(p) From 16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 16:32:18 +0800 Subject: [PATCH 0549/1174] [IPEX] Fix SDPA attn_mask dtype --- modules/xpu_specific.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index f7687a66cc2..4e11125b20b 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention( # cast to same dtype first key = key.to(query.dtype) value = value.to(query.dtype) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(query.dtype) N = query.shape[:-2] # Batch size L = query.size(-2) # Target sequence length From ec9acb31450536a9258192351d6a26421efd7eb4 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 17:17:53 +0800 Subject: [PATCH 0550/1174] Handle CondFunc exception when resolving attributes --- modules/sd_hijack_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index f8684475ec5..79bf6e46862 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -11,10 +11,14 @@ def __new__(cls, orig_func, sub_func, cond_func): break except ImportError: pass - for attr_name in func_path[i:-1]: - resolved_obj = getattr(resolved_obj, attr_name) - orig_func = getattr(resolved_obj, func_path[-1]) - setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + try: + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + except AttributeError: + print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack") + pass self.__init__(orig_func, sub_func, cond_func) return lambda *args, **kwargs: self(*args, **kwargs) def __init__(self, orig_func, sub_func, cond_func): From 73786c047f14d6ae658b2c12f493f05486ba1789 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 19:09:56 +0800 Subject: [PATCH 0551/1174] [IPEX] Fix torch.Generator hijack --- modules/xpu_specific.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 4e11125b20b..1137891a61c 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention( return torch.reshape(result, (*N, L, Ev)) +def is_xpu_device(device: str | torch.device = None): + if device is None: + return False + if isinstance(device, str): + return device.startswith("xpu") + return device.type == "xpu" + + if has_xpu: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device.type == "xpu") + try: + # torch.Generator supports "xpu" device since 2.1 + torch.Generator("xpu") + except: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: is_xpu_device(device)) # W/A for some OPs that could not handle different input dtypes 
CondFunc('torch.nn.functional.layer_norm', From 818d6a11e709bf07d48606bdccab944c46a5f4b0 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 19:14:06 +0800 Subject: [PATCH 0552/1174] Fix format --- modules/xpu_specific.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 1137891a61c..2971dbc3cf5 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -106,8 +106,8 @@ def is_xpu_device(device: str | torch.device = None): try: # torch.Generator supports "xpu" device since 2.1 torch.Generator("xpu") - except: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + except RuntimeError: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1) CondFunc('torch.Generator', lambda orig_func, device=None: torch.xpu.Generator(device), lambda orig_func, device=None: is_xpu_device(device)) From a183de04e3f965083e7f3462201327d30c36b958 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 20:03:33 +0800 Subject: [PATCH 0553/1174] Execute model_loaded_callback after moving to target device --- modules/sd_models.py | 6 +++--- modules/sd_vae.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 50bc209e465..2c04577152c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): sd_hijack.model_hijack.hijack(sd_model) timer.record("hijack") - script_callbacks.model_loaded_callback(sd_model) - timer.record("script callbacks") - if not sd_model.lowvram: sd_model.to(devices.device) timer.record("move model to device") + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") + print(f"Weights loaded in {timer.summary()}.") model_data.set_sd_model(sd_model) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 31306d8ba4b..43687e48dcf 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) if not sd_model.lowvram: sd_model.to(devices.device) + script_callbacks.model_loaded_callback(sd_model) + print("VAE weights loaded.") return sd_model From 2f98a35fc4508494355c01ec45f5bec725f570a6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 7 Jan 2024 09:21:21 +0300 Subject: [PATCH 0554/1174] add assets repo; serve fonts locally rather than from google's servers --- modules/launch_utils.py | 3 +++ modules/sysinfo.py | 2 ++ modules/ui.py | 4 +++- style.css | 2 +- 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c2a7ae932ce..8e58d714550 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -344,11 +344,13 @@ def prepare_environment(): clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") + assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") 
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') + assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") @@ -405,6 +407,7 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) + git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) diff --git a/modules/sysinfo.py b/modules/sysinfo.py index 5abf616b707..f336251e445 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -24,9 +24,11 @@ "XFORMERS_PACKAGE", "CLIP_PACKAGE", "OPENCLIP_PACKAGE", + "ASSETS_REPO", "STABLE_DIFFUSION_REPO", "K_DIFFUSION_REPO", "BLIP_REPO", + "ASSETS_COMMIT_HASH", "STABLE_DIFFUSION_COMMIT_HASH", "K_DIFFUSION_COMMIT_HASH", "BLIP_COMMIT_HASH", diff --git a/modules/ui.py b/modules/ui.py index 2d2e333b293..a716a04055d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -13,7 +13,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import gradio_extensons # noqa: F401 -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow from modules.paths import script_path from modules.ui_common import create_refresh_button @@ -1223,3 +1223,5 @@ def download_sysinfo(attachment=False): app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"]) app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"]) + import fastapi.staticfiles + app.mount("/webui-assets", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name="webui-assets") diff --git a/style.css b/style.css index 6d4c8a0d505..4957c523df3 100644 --- a/style.css +++ b/style.css @@ -1,6 +1,6 @@ /* temporary fix to load default gradio font in frontend instead of backend */ -@import 
url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); +@import url('webui-assets/css/sourcesanspro.css'); /* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */ From 425507bd10c55f1f804eb5015db74520668f46f9 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 7 Jan 2024 10:25:01 -0600 Subject: [PATCH 0555/1174] add p to cfgdenoiserparams --- modules/script_callbacks.py | 5 ++++- modules/sd_samplers_cfg_denoiser.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 9ed7ad21d1b..bb47c18d3f1 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -41,7 +41,7 @@ def __init__(self, noise, x, xi): class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, p): self.x = x """Latent image representation in the process of being denoised""" @@ -63,6 +63,9 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, te self.text_uncond = text_uncond """ Encoder hidden states of text conditioning from negative prompt""" + self.p = p + """StableDiffusionProcessing object with processing parameters""" + class CFGDenoisedParams: def __init__(self, x, sampling_step, total_sampling_steps, inner_model): diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index eb9d5dafacc..f4ded6bdbe9 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -146,7 +146,7 @@ def apply_blend(current_latent): sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self.p) cfg_denoiser_callback(denoiser_params) x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond From f56cebf5ba24313447b2204c3f804379767201c9 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 7 Jan 2024 12:35:35 -0600 Subject: [PATCH 0556/1174] add self instead --- modules/script_callbacks.py | 6 +++--- modules/sd_samplers_cfg_denoiser.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index bb47c18d3f1..053dfc96309 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -41,7 +41,7 @@ def __init__(self, noise, x, xi): class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, p): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser): self.x = x """Latent image representation in the process of being denoised""" @@ -63,8 +63,8 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, te self.text_uncond = text_uncond """ Encoder hidden states of text conditioning from negative prompt""" - self.p = p - """StableDiffusionProcessing object with processing 
parameters""" + self.denoiser = denoiser + """Current CFGDenoiser object with processing parameters""" class CFGDenoisedParams: diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index f4ded6bdbe9..6d76aa965a5 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -146,7 +146,7 @@ def apply_blend(current_latent): sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self.p) + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self) cfg_denoiser_callback(denoiser_params) x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond From 37906e429ae5a92f9ea96ab6dc2157b1d7c4d8b6 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 7 Jan 2024 20:17:42 -0600 Subject: [PATCH 0557/1174] make denoiser None by default --- modules/script_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 053dfc96309..a54cb3ebb56 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -41,7 +41,7 @@ def __init__(self, noise, x, xi): class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None): self.x = x """Latent image representation in the process of being denoised""" From 8e292373ec5c54493ce48af7c76f5eaa79dc8abd Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Mon, 8 Jan 2024 06:43:39 -0600 Subject: [PATCH 0558/1174] lcm sampler --- modules/sd_samplers.py | 3 +- modules/sd_samplers_lcm.py | 104 +++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 modules/sd_samplers_lcm.py diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 45faae62821..a58528a0b52 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,4 +1,4 @@ -from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 @@ -6,6 +6,7 @@ all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, *sd_samplers_timesteps.samplers_data_timesteps, + *sd_samplers_lcm.samplers_data_lcm, ] all_samplers_map = {x.name: x for x in all_samplers} diff --git a/modules/sd_samplers_lcm.py b/modules/sd_samplers_lcm.py new file mode 100644 index 00000000000..59839b720dd --- /dev/null +++ b/modules/sd_samplers_lcm.py @@ -0,0 +1,104 @@ +import torch + +from k_diffusion import utils, sampling +from k_diffusion.external import DiscreteEpsDDPMDenoiser +from k_diffusion.sampling import default_noise_sampler, trange + +from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common + + +class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser): + def __init__(self, model): + 
timesteps = 1000 + original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM) + self.skip_steps = timesteps // original_timesteps + + alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device) + for x in range(original_timesteps): + alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps] + + super().__init__(model, alphas_cumprod_valid, quantize=None) + + + def get_sigmas(self, n=None,): + if n is None: + return sampling.append_zero(self.sigmas.flip(0)) + + start = self.sigma_to_t(self.sigma_max) + end = self.sigma_to_t(self.sigma_min) + + t = torch.linspace(start, end, n, device=shared.sd_model.device) + + return sampling.append_zero(self.t_to_sigma(t)) + + + def sigma_to_t(self, sigma, quantize=None): + log_sigma = sigma.log() + dists = log_sigma - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) + + + def t_to_sigma(self, timestep): + t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + return super().t_to_sigma(t) + + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) + + + def get_scaled_out(self, sigma, output, input): + sigma_data = 0.5 + scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0 + + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + + return c_out * output + c_skip * input + + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return self.get_scaled_out(sigma, input + eps * c_out, input) + + +def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + x = denoised + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + return x + + +class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser): + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = LCMCompVisDenoiser + self.model_wrap = denoiser(shared.sd_model) + + return self.model_wrap + + +class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler): + def __init__(self, funcname, sd_model, options=None): + super().__init__(funcname, sd_model, options) + self.model_wrap_cfg = CFGDenoiserLCM(self) + self.model_wrap = self.model_wrap_cfg.inner_model + + +samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})] +samplers_data_lcm = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_lcm +] From df8aa69a99e38ae59a4e599b9dff11eccf3490f4 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 14:10:03 -0500 Subject: [PATCH 0559/1174] Add tree-view display for extra networks. 
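The tree view hangs off a nested dict built by the new get_tree() helper in
this patch. A rough sketch of the intended shape (hypothetical paths, and it
assumes those files actually exist on disk, since get_tree() walks the
filesystem):

    from modules.ui_extra_networks import ExtraNetworksItem, get_tree

    # The wrapped dicts come from each page's list_items(); abbreviated here.
    items = {
        "/abs/models/Lora/styles/ink.safetensors": ExtraNetworksItem({"name": "ink"}),
        "/abs/models/Lora/portrait.safetensors": ExtraNetworksItem({"name": "portrait"}),
    }
    tree = get_tree(["/abs/models/Lora"], items=items)
    # tree == {"/abs/models/Lora": {
    #     "styles": {"ink.safetensors": ExtraNetworksItem(...)},
    #     "portrait.safetensors": ExtraNetworksItem(...),
    # }}

Empty subdirectories are pruned, while each root keeps its key even when it
holds nothing, so the UI can later tell the user that the directory is empty.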
--- html/extra-networks-card-minimal.html | 3 + html/extra-networks-card.html | 1 + javascript/extraNetworks.js | 5 + modules/shared_options.py | 1 + modules/ui_extra_networks.py | 352 ++++++++++++++++++-------- style.css | 92 +++++++ 6 files changed, 355 insertions(+), 99 deletions(-) create mode 100644 html/extra-networks-card-minimal.html diff --git a/html/extra-networks-card-minimal.html b/html/extra-networks-card-minimal.html new file mode 100644 index 00000000000..a6a54d9f4ac --- /dev/null +++ b/html/extra-networks-card-minimal.html @@ -0,0 +1,3 @@ +
    + {name}{copy_path_button}{metadata_button}{edit_button} +
    diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 39674666f1e..d76011d737c 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,6 +1,7 @@
    {background_image}
    + {copy_path_button} {metadata_button} {edit_button}
    diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 98a7abb745c..40309d5575c 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -337,6 +337,11 @@ function requestGet(url, data, handler, errorHandler) { xhr.send(js); } +function extraNetworksCopyCardPath(event, path) { + navigator.clipboard.writeText(path); + event.stopPropagation(); +} + function extraNetworksRequestMetadata(event, extraPage, cardName) { var showError = function() { extraNetworksShowMetadata("there was an error getting metadata"); diff --git a/modules/shared_options.py b/modules/shared_options.py index d2e86ff10b3..e698c26498c 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -238,6 +238,7 @@ "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), + "extra_networks_tree_view": OptionInfo(False, "Show extra networks using a directory tree view.").needs_reload_ui(), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index fe5d3ba3338..8667617b9f5 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,6 +2,8 @@ import os.path import urllib.parse from pathlib import Path +from typing import Optional, Union +from dataclasses import dataclass from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo @@ -15,10 +17,8 @@ extra_pages = [] allowed_dirs = set() - default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] - @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) @@ -28,6 +28,58 @@ def allowed_preview_extensions(): return allowed_preview_extensions_with_extra((shared.opts.samples_format, )) +@dataclass +class ExtraNetworksItem: + """Wrapper for dictionaries representing ExtraNetworks items.""" + item: dict + + +def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) -> dict: + """Recursively builds a directory tree. + + Args: + paths: Path or list of paths to directories. These paths are treated as roots from which + the tree will be built. + items: A dictionary associating filepaths to an ExtraNetworksItem instance. + + Returns: + The result directory tree. + """ + if isinstance(paths, (str,)): + paths = [paths] + + def _get_tree(_paths: list[str]): + _res = {} + for path in _paths: + if os.path.isdir(path): + dir_items = os.listdir(path) + # Ignore empty directories. 
+ if not dir_items: + continue + dir_tree = _get_tree([os.path.join(path, x) for x in dir_items]) + # We only want to store non-empty folders in the tree. + if dir_tree: + _res[os.path.basename(path)] = dir_tree + else: + if path not in items: + continue + # Add the ExtraNetworksItem to the result. + _res[os.path.basename(path)] = items[path] + return _res + + res = {} + # Handle each root directory separately. + # Each root WILL have a key/value at the root of the result dict though + # the value can be an empty dict if the directory is empty. We want these + # placeholders for empty dirs so we can inform the user later. + for path in paths: + # Wrap the path in a list since that is what the `_get_tree` expects. + res[path] = _get_tree([path]) + if res[path]: + res[path] = res[path][os.path.basename(path)] + + return res + def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" @@ -80,7 +132,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): item = page.items.get(name) page.read_user_metadata(item) - item_html = page.create_html_for_item(item, tabname) + item_html = page.create_item_html(tabname, item) return JSONResponse({"html": item_html}) @@ -96,13 +148,15 @@ def quote_js(s): s = s.replace('"', '\\"') return f'"{s}"' - class ExtraNetworksPage: def __init__(self, title): self.title = title self.name = title.lower() self.id_page = self.name.replace(" ", "_") - self.card_page = shared.html("extra-networks-card.html") + if shared.opts.extra_networks_tree_view: + self.card_page = shared.html("extra-networks-card-minimal.html") + else: + self.card_page = shared.html("extra-networks-card.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -136,12 +190,141 @@ def search_terms_from_path(self, filename, possible_directories=None): return "" - def create_html(self, tabname): - items_html = '' + def create_item_html(self, tabname: str, item: dict) -> str: + """Generates HTML for a single ExtraNetworks Item + + Args: + tabname: The name of the active tab. + item: Dictionary containing item information. + + Returns: + HTML string generated for this item. + Can be empty if the item is not meant to be shown. + """ + metadata = item.get("metadata") + if metadata: + self.metadata[item["name"]] = metadata + + if "user_metadata" not in item: + self.read_user_metadata(item) + + preview = item.get("preview", None) + height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' + width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + background_image = f'' if preview else '' + + onclick = item.get("onclick", None) + if onclick is None: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + + copy_path_button = f"
    " + + metadata_button = "" + metadata = item.get("metadata") + if metadata: + metadata_button = f"" + + edit_button = f"
    " + + local_path = "" + filename = item.get("filename", "") + for reldir in self.allowed_directories_for_previews(): + absdir = os.path.abspath(reldir) + + if filename.startswith(absdir): + local_path = filename[len(absdir):] + + # if this is true, the item must not be shown in the default view, and must instead only be + # shown when searching for it + if shared.opts.extra_networks_hidden_models == "Always": + search_only = False + else: + search_only = "/." in local_path or "\\." in local_path + + if search_only and shared.opts.extra_networks_hidden_models == "Never": + return "" + + sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() + + # Some items here might not be used depending on HTML template used. + args = { + "background_image": background_image, + "card_clicked": onclick, + "copy_path_button": copy_path_button, + "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), + "edit_button": edit_button, + "local_preview": quote_js(item["local_preview"]), + "metadata_button": metadata_button, + "name": html.escape(item["name"]), + "prompt": item.get("prompt", None), + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', + "search_only": " search_only" if search_only else "", + "search_term": item.get("search_term", ""), + "sort_keys": sort_keys, + "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", + "tabname": quote_js(tabname), + } + + return self.card_page.format(**args) + + def create_tree_view_html(self, tabname: str) -> str: + """Generates HTML for displaying folders in a tree view. + + Args: + tabname: The name of the active tab. + + Returns: + HTML string generated for this tree view. + """ + self_name_id = self.name.replace(" ", "_") + res = f"
    " self.metadata = {} + self.items = {x["name"]: x for x in self.list_items()} + roots = self.allowed_directories_for_previews() + tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()} + tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items) + + if not tree: + return res + "
    " + + file_template = "
<li class='file-item'>{}</li>
  • " + dir_template = ( + "
    " + "{}" + "{}
" + "" + ) + + def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: + """Recursively builds HTML for a tree.""" + _res = "
    " + if not data: + return "
<li>DIRECTORY IS EMPTY</li>
    " + + for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): + if isinstance(v, (ExtraNetworksItem,)): + _res += file_template.format(self.create_item_html(tabname, v.item)) + else: + _res += dir_template.format("", k, _build_tree(v)) + return _res + + res += "
      " + # Add each root directory to the tree. + for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): + # If root is empty, append the "disabled" attribute to the template details tag. + res += dir_template.format("open" if v else "open disabled", k, _build_tree(v)) + res += "
    " + res += "" + + return res + + def create_card_view_html(self, tabname): + items_html = "" + self.metadata = {} subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): for dirname in sorted(dirs, key=shared.natural_sort_key): @@ -171,40 +354,45 @@ def create_html(self, tabname): if subdirs: subdirs = {"": 1, **subdirs} - subdirs_html = "".join([f""" - -""" for subdir in subdirs]) + subdirs_html_template = ( + "" + ) + subdirs_html = "".join( + [ + subdirs_html_template.format( + " search-all" if subdir == "" else "", + tabname, + html.escape(subdir if subdir != "" else "all"), + ) for subdir in subdirs + ] + ) self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): - metadata = item.get("metadata") - if metadata: - self.metadata[item["name"]] = metadata - - if "user_metadata" not in item: - self.read_user_metadata(item) - - items_html += self.create_html_for_item(item, tabname) + items_html += self.create_item_html(tabname, item) - if items_html == '': + if items_html == "": dirs = "".join([f"
<li>{x}</li>
  • " for x in self.allowed_directories_for_previews()]) items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) self_name_id = self.name.replace(" ", "_") - res = f""" -
    -{subdirs_html} -
    -
    -{items_html} -
    -""" + res = ( + f"
    {subdirs_html}
    " + f"
    {items_html}
    " + ) return res + def create_html(self, tabname): + if shared.opts.extra_networks_tree_view: + return self.create_tree_view_html(tabname) + else: + return self.create_card_view_html(tabname) + def create_item(self, name, index=None): raise NotImplementedError() @@ -214,66 +402,6 @@ def list_items(self): def allowed_directories_for_previews(self): return [] - def create_html_for_item(self, item, tabname): - """ - Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown. - """ - - preview = item.get("preview", None) - - onclick = item.get("onclick", None) - if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' - - height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' - width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' - background_image = f'' if preview else '' - metadata_button = "" - metadata = item.get("metadata") - if metadata: - metadata_button = f"" - - edit_button = f"
    " - - local_path = "" - filename = item.get("filename", "") - for reldir in self.allowed_directories_for_previews(): - absdir = os.path.abspath(reldir) - - if filename.startswith(absdir): - local_path = filename[len(absdir):] - - # if this is true, the item must not be shown in the default view, and must instead only be - # shown when searching for it - if shared.opts.extra_networks_hidden_models == "Always": - search_only = False - else: - search_only = "/." in local_path or "\\." in local_path - - if search_only and shared.opts.extra_networks_hidden_models == "Never": - return "" - - sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() - - args = { - "background_image": background_image, - "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", - "prompt": item.get("prompt", None), - "tabname": quote_js(tabname), - "local_preview": quote_js(item["local_preview"]), - "name": html.escape(item["name"]), - "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), - "card_clicked": onclick, - "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', - "search_term": item.get("search_term", ""), - "metadata_button": metadata_button, - "edit_button": edit_button, - "search_only": " search_only" if search_only else "", - "sort_keys": sort_keys, - } - - return self.card_page.format(**args) - def get_sort_keys(self, path): """ List of default keys used for sorting in the UI. @@ -360,7 +488,6 @@ def tab_name_score(name): return sorted(pages, key=lambda x: tab_scores[x.name]) - def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): from modules.ui import switch_values_symbol @@ -381,7 +508,6 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) editor = page.create_user_metadata_editor(ui, tabname) @@ -390,30 +516,60 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) + tab_controls = {} + edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) + + tab_controls["edit_search"] = edit_search + tab_controls["dropdown_sort"] = dropdown_sort + tab_controls["button_sortorder"] = button_sortorder + tab_controls["button_refresh"] = button_refresh + tab_controls["checkbox_show_dirs"] = 
checkbox_show_dirs ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs] - for tab in unrelated_tabs: - tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False) + tab.select( + fn=lambda: [gr.update(visible=False) for _ in tab_controls], + _js="function(){ extraNetworksUrelatedTabSelected('" + tabname + "'); }", + inputs=[], + outputs=list(tab_controls.values()), + show_progress=False, + ) + + visible_controls = list(tab_controls.keys()) + if shared.opts.extra_networks_tree_view: + visible_controls = ["button_refresh"] for page, tab in zip(ui.stored_extra_pages, related_tabs): allow_prompt = "true" if page.allow_prompt else "false" allow_negative_prompt = "true" if page.allow_negative_prompt else "false" - jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');' - - tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False) + jscode = ( + "extraNetworksTabSelected(" + f"'{tabname}', " + f"'{tabname}_{page.id_page}_prompts', " + f"'{allow_prompt}', " + f"'{allow_negative_prompt}'" + ");" + ) + + tab.select( + fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], + _js="function(){ " + jscode + " }", + inputs=[], + outputs=list(tab_controls.values()), + show_progress=False, + ) dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") + def pages_html(): if not ui.pages_contents: return refresh() @@ -478,5 +634,3 @@ def save_preview(index, images, filename): for editor in ui.user_metadata_editors: editor.setup_ui(gallery) - - diff --git a/style.css b/style.css index ee39a57b73e..680a5f837b7 100644 --- a/style.css +++ b/style.css @@ -1157,3 +1157,95 @@ body.resizing .resize-handle { left: 7.5px; border-left: 1px dashed var(--border-color-primary); } + +.extra-network-cards .card .copy-path-button:before { + content: "⎘"; +} + +.extra-network-cards .card-minimal .button-column { + display: inline-flex; + visibility: hidden; + color: white; + padding-left: 0.5rem; + padding-right: 0.5rem; + align-items: center; +} + +.extra-network-cards .card-minimal:hover .button-column { + visibility: visible; +} + +.extra-network-cards .card-minimal .copy-path-button:before { + content: "⎘"; +} + +.extra-network-cards .card-minimal .metadata-button:before{ + content: "🛈"; +} + +.extra-network-cards .card-minimal .edit-button:before{ + content: "🛠"; +} + +.extra-network-cards .card-minimal .card-button { + color: white; + text-shadow: 2px 2px 3px black; + font-size: 1rem; + width: 1.5rem; +} + +.extra-network-cards .card-minimal .card-button:hover { + color: red; +} + +.extra-network-cards .card-minimal { + display: inline-flex; + position: relative; + overflow: hidden; + cursor: pointer; + font-size: 1rem; + font-weight: bold; + line-break: anywhere; +} + +.file-item { + list-style-type: '📄'; +} + +/* prevents clicking/collapsing of details tags when disabled attribute is used*/ +details[disabled] summary { + pointer-events: none; + user-select: 
none; +} + +details.folder-item > summary { + list-style-type: '📁'; +} + +details.folder-item[open] > summary { + list-style-type: '📂'; +} + +.file-item, +.folder-item, +.folder-item-summary { + display: block; + font-size: 1rem; + padding: 0.05rem; + cursor: pointer; + user-select: none; +} + +.folder-item-summary:hover, +.file-item:hover { + -webkit-transition: all 0.1s ease-in-out; + transition: all 0.1s ease-in-out; + background-color: var(--neutral-200); +} + +.dark .folder-item-summary:hover, +.dark .file-item:hover { + -webkit-transition: all 0.05s ease-in-out; + transition: all 0.05s ease-in-out; + background-color: var(--neutral-800); +} From 67a70ad112e1b0fb262852ab830896302dbd306a Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 14:11:24 -0500 Subject: [PATCH 0560/1174] fix indentation --- html/extra-networks-card.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index d76011d737c..7770094da14 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,7 +1,7 @@
    {background_image}
    - {copy_path_button} + {copy_path_button} {metadata_button} {edit_button}
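With the template refactor in the previous patch, every card (the grid card
and the new minimal tree-view card alike) now renders through the same
create_item_html() helper, which is also what the
/sd_extra_networks/get-single-card endpoint calls. A minimal sketch of that
call path (the page lookup on the first line is hypothetical; the rest
mirrors get_single_card()):

    page = next(p for p in extra_pages if p.name == "lora")  # hypothetical lookup
    item = page.items.get(name)
    page.read_user_metadata(item)
    item_html = page.create_item_html(tabname, item)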
    From 34fc215249e2bc0acc66cda47319e40b6e46a05f Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 14:23:01 -0500 Subject: [PATCH 0561/1174] fix linting --- modules/ui_extra_networks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 8667617b9f5..ab484a5dc0b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -216,7 +216,7 @@ def create_item_html(self, tabname: str, item: dict) -> str: onclick = item.get("onclick", None) if onclick is None: onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' - + copy_path_button = f"
    " metadata_button = "" @@ -523,7 +523,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) - + tab_controls["edit_search"] = edit_search tab_controls["dropdown_sort"] = dropdown_sort tab_controls["button_sortorder"] = button_sortorder @@ -560,7 +560,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ) tab.select( - fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], + fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], _js="function(){ " + jscode + " }", inputs=[], outputs=list(tab_controls.values()), From 46413b20a4d88355b5fac5d27cc9fcb634cdb49e Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 16:23:35 -0500 Subject: [PATCH 0562/1174] Increase limits for upscalers. --- modules/api/models.py | 2 +- scripts/postprocessing_upscale.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/api/models.py b/modules/api/models.py index 33894b3e694..11ba4f0b952 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -143,7 +143,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. 
Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index ed709688de4..5946678e5be 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -21,13 +21,13 @@ def ui(self): with FormRow(): with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + upscaling_resize = gr.Slider(minimum=1.0, maximum=100.0, step=0.05, label="Resize", value=2, elem_id="extras_upscaling_resize") with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: with FormRow(): with gr.Column(elem_id="upscaling_column_size", scale=4): - upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") + upscaling_resize_w = gr.Slider(minimum=64, maximum=8192, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Slider(minimum=64, maximum=8192, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"): upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height") upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") From 8d986727b39ee6616190559efa1c41c1942b99b0 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 9 Jan 2024 03:01:20 -0600 Subject: [PATCH 0563/1174] include tls arguments in api uvicorn init --- modules/api/api.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9d1292e9571..59e46335231 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -879,7 +879,15 @@ def get_extensions_list(self): def launch(self, server_name, port, root_path): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) + uvicorn.run( + self.app, + host=server_name, + port=port, + timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, + root_path=root_path, + ssl_keyfile=shared.cmd_opts.tls_keyfile, + ssl_certfile=shared.cmd_opts.tls_certfile + ) def kill_webui(self): restart.stop_program() From 209c26a1cb9e4be357ab3c5e7613caf3cbc26183 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 22:11:44 +0800 Subject: [PATCH 0564/1174] improve efficiency and support more device --- modules/devices.py | 60 ++++++++++++++++++++++++++++++------------ modules/shared_init.py | 1 + 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ff279ac5016..6edfb127896 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -110,6 +110,7 @@ def enable_tf32(): dtype: torch.dtype = torch.float16 dtype_vae: torch.dtype = torch.float16 dtype_unet: torch.dtype = torch.float16 +dtype_inference: torch.dtype = torch.float16 unet_needs_upcast = False @@ -131,21 +132,49 @@ def 
cond_cast_float(input): ] -def manual_cast_forward(self, *args, **kwargs): - org_dtype = torch_utils.get_param(self).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result +def manual_cast_forward(target_dtype): + def forward_wrapper(self, *args, **kwargs): + org_dtype = torch_utils.get_param(self).dtype + if not target_dtype == org_dtype == dtype_inference: + self.to(target_dtype) + args = [ + arg.to(target_dtype) + if isinstance(arg, torch.Tensor) + else arg + for arg in args + ] + kwargs = { + k: v.to(target_dtype) + if isinstance(v, torch.Tensor) + else v + for k, v in kwargs.items() + } + + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + + if target_dtype != dtype_inference: + if isinstance(result, tuple): + result = tuple( + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i + for i in result + ) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) + return result + return forward_wrapper @contextlib.contextmanager -def manual_cast(): +def manual_cast(target_dtype): for module_type in patch_module_list: org_forward = module_type.forward - module_type.forward = manual_cast_forward + if module_type == torch.nn.MultiheadAttention and has_xpu(): + module_type.forward = manual_cast_forward(torch.float32) + else: + module_type.forward = manual_cast_forward(target_dtype) module_type.org_forward = org_forward try: yield None @@ -161,15 +190,12 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): - return manual_cast() - - if has_mps() and shared.cmd_opts.precision != "full": - return manual_cast() - - if dtype == torch.float32 or shared.cmd_opts.precision == "full": + if dtype == torch.float32 and shared.cmd_opts.precision == "full": return contextlib.nullcontext() + if has_xpu() or has_mps() or cuda_no_autocast(): + return manual_cast(dtype_inference) + return torch.autocast("cuda") diff --git a/modules/shared_init.py b/modules/shared_init.py index 586be3423d0..935e3a21cf2 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -29,6 +29,7 @@ def initialize(): devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 + devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype shared.device = devices.device shared.weight_load_location = None if cmd_opts.lowram else "cpu" From 42e6df723c68af775b73c9fa4f43f99345348689 Mon Sep 17 00:00:00 2001 From: KohakuBlueleaf Date: Tue, 9 Jan 2024 22:39:39 +0800 Subject: [PATCH 0565/1174] Fix bugs when arg dtype doesn't match --- modules/devices.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 6edfb127896..e05740524f3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -134,24 +134,19 @@ def cond_cast_float(input): def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): + if any( + isinstance(arg, torch.Tensor) and arg.dtype != target_dtype + for arg in args + ): + args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg 
for arg in args] + kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + org_dtype = torch_utils.get_param(self).dtype - if not target_dtype == org_dtype == dtype_inference: + if org_dtype != target_dtype: self.to(target_dtype) - args = [ - arg.to(target_dtype) - if isinstance(arg, torch.Tensor) - else arg - for arg in args - ] - kwargs = { - k: v.to(target_dtype) - if isinstance(v, torch.Tensor) - else v - for k, v in kwargs.items() - } - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) + if org_dtype != target_dtype: + self.to(org_dtype) if target_dtype != dtype_inference: if isinstance(result, tuple): From c2c05fcca8f3547783c5440c04ec10cc63c65db5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 22:53:58 +0800 Subject: [PATCH 0566/1174] linting and debugs --- modules/devices.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index e05740524f3..ad36f6562bc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -140,20 +140,20 @@ def forward_wrapper(self, *args, **kwargs): ): args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - + org_dtype = torch_utils.get_param(self).dtype if org_dtype != target_dtype: self.to(target_dtype) result = self.org_forward(*args, **kwargs) if org_dtype != target_dtype: self.to(org_dtype) - + if target_dtype != dtype_inference: if isinstance(result, tuple): result = tuple( - i.to(dtype_inference) - if isinstance(i, torch.Tensor) - else i + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i for i in result ) elif isinstance(result, torch.Tensor): @@ -185,7 +185,7 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if dtype == torch.float32 and shared.cmd_opts.precision == "full": + if dtype == torch.float32: return contextlib.nullcontext() if has_xpu() or has_mps() or cuda_no_autocast(): From e00365962b17550a42235d1fbe2ad2c7cc4b8961 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:13:34 +0800 Subject: [PATCH 0567/1174] Apply correct inference precision implementation --- modules/devices.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ad36f6562bc..9e1f207c336 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -132,6 +132,21 @@ def cond_cast_float(input): ] +def cast_output(result): + if isinstance(result, tuple): + result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) + return result + + +def autocast_with_cast_output(self, *args, **kwargs): + result = self.org_forward(*args, **kwargs) + if dtype_inference != dtype: + result = cast_output(result) + return result + + def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): if any( @@ -149,15 +164,7 @@ def forward_wrapper(self, *args, **kwargs): self.to(org_dtype) if target_dtype != dtype_inference: - if isinstance(result, tuple): - result = tuple( - i.to(dtype_inference) - if isinstance(i, torch.Tensor) - else i - for i in result - ) - elif isinstance(result, 
torch.Tensor): - result = result.to(dtype_inference) + result = cast_output(result) return result return forward_wrapper @@ -178,6 +185,20 @@ def manual_cast(target_dtype): module_type.forward = module_type.org_forward +@contextlib.contextmanager +def precision_full_with_autocast(autocast_ctx): + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = autocast_with_cast_output + module_type.org_forward = org_forward + try: + with autocast_ctx: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward + + def autocast(disable=False): if disable: return contextlib.nullcontext() @@ -191,6 +212,9 @@ def autocast(disable=False): if has_xpu() or has_mps() or cuda_no_autocast(): return manual_cast(dtype_inference) + if dtype_inference == torch.float32 and dtype != torch.float32: + return precision_full_with_autocast(torch.autocast("cuda")) + return torch.autocast("cuda") From 1fd69655fe340325863cbd7bf5297e034a6a3a0a Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:15:05 +0800 Subject: [PATCH 0568/1174] Revert "Apply correct inference precision implementation" This reverts commit e00365962b17550a42235d1fbe2ad2c7cc4b8961. --- modules/devices.py | 42 +++++++++--------------------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 9e1f207c336..ad36f6562bc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -132,21 +132,6 @@ def cond_cast_float(input): ] -def cast_output(result): - if isinstance(result, tuple): - result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result) - elif isinstance(result, torch.Tensor): - result = result.to(dtype_inference) - return result - - -def autocast_with_cast_output(self, *args, **kwargs): - result = self.org_forward(*args, **kwargs) - if dtype_inference != dtype: - result = cast_output(result) - return result - - def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): if any( @@ -164,7 +149,15 @@ def forward_wrapper(self, *args, **kwargs): self.to(org_dtype) if target_dtype != dtype_inference: - result = cast_output(result) + if isinstance(result, tuple): + result = tuple( + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i + for i in result + ) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) return result return forward_wrapper @@ -185,20 +178,6 @@ def manual_cast(target_dtype): module_type.forward = module_type.org_forward -@contextlib.contextmanager -def precision_full_with_autocast(autocast_ctx): - for module_type in patch_module_list: - org_forward = module_type.forward - module_type.forward = autocast_with_cast_output - module_type.org_forward = org_forward - try: - with autocast_ctx: - yield None - finally: - for module_type in patch_module_list: - module_type.forward = module_type.org_forward - - def autocast(disable=False): if disable: return contextlib.nullcontext() @@ -212,9 +191,6 @@ def autocast(disable=False): if has_xpu() or has_mps() or cuda_no_autocast(): return manual_cast(dtype_inference) - if dtype_inference == torch.float32 and dtype != torch.float32: - return precision_full_with_autocast(torch.autocast("cuda")) - return torch.autocast("cuda") From 58d5b042cd02f287faabef399134b97d323691f2 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 
2024 23:23:40 +0800 Subject: [PATCH 0569/1174] Apply the correct behavior of precision='full' --- modules/devices.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ad36f6562bc..29a270d11a6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -185,11 +185,14 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if dtype == torch.float32: - return contextlib.nullcontext() - if has_xpu() or has_mps() or cuda_no_autocast(): - return manual_cast(dtype_inference) + return manual_cast(dtype) + + if fp8 and dtype_inference == torch.float32: + return manual_cast(dtype) + + if dtype == torch.float32 or dtype_inference == torch.float32: + return contextlib.nullcontext() return torch.autocast("cuda") From ca671e5d7b9d03227f01e6bcb350032b6d14e722 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:30:55 +0800 Subject: [PATCH 0570/1174] rearrange if-statements for cpu --- modules/devices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 29a270d11a6..0321d12c6ab 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -185,15 +185,15 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if has_xpu() or has_mps() or cuda_no_autocast(): - return manual_cast(dtype) - if fp8 and dtype_inference == torch.float32: return manual_cast(dtype) if dtype == torch.float32 or dtype_inference == torch.float32: return contextlib.nullcontext() + if has_xpu() or has_mps() or cuda_no_autocast(): + return manual_cast(dtype) + return torch.autocast("cuda") From 3cc7572f5dec429d265ce937249eb4b5cc18d0ba Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 9 Jan 2024 11:46:10 -0500 Subject: [PATCH 0571/1174] Restore scale factor limit changes to master branch. --- modules/api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/models.py b/modules/api/models.py index 11ba4f0b952..33894b3e694 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -143,7 +143,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. 
Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") From 9dd25348248c06770897f4cde64e2663dfa9f2de Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 9 Jan 2024 11:47:39 -0500 Subject: [PATCH 0572/1174] Restore scale factor limit changes to master branch. --- scripts/postprocessing_upscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 5946678e5be..a57f9d4a4b6 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -21,7 +21,7 @@ def ui(self): with FormRow(): with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=1.0, maximum=100.0, step=0.05, label="Resize", value=2, elem_id="extras_upscaling_resize") + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: with FormRow(): From 4d9f2c3ec8e791b9c354292c4333fc85b8f8d740 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 6 Jan 2024 21:41:29 +0900 Subject: [PATCH 0573/1174] update p.seed and p.subseed --- modules/txt2img.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/txt2img.py b/modules/txt2img.py index c4cc12d2f6d..d22a1f319e3 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -67,13 +67,16 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g geninfo = json.loads(generation_info) all_seeds = geninfo["all_seeds"] + all_subseeds = geninfo["all_subseeds"] image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] p.firstpass_image = infotext_utils.image_from_url_text(image_info) gallery_index_from_end = len(gallery) - gallery_index seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] - p.script_args = modules.scripts.scripts_txt2img.set_named_arg(p.script_args, 'ScriptSeed', 'seed', seed) + subseed = all_subseeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] + p.seed = seed + p.subseed = subseed with closing(p): processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) From 3db6938caa719aaa38b52edecf42740ef62b0c3c Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Wed, 10 Jan 2024 18:11:48 -0500 Subject: [PATCH 0574/1174] begin redesign of tree module. --- html/extra-networks-card-minimal.html | 3 +- html/extra-networks-pane.html | 11 ++ html/extra-networks-tree-directory.html | 4 + html/extra-networks-tree-file.html | 1 + javascript/extraNetworks.js | 22 +++ modules/shared_options.py | 1 - modules/ui_extra_networks.py | 164 ++++++++++--------- style.css | 204 +++++++++++++----------- 8 files changed, 248 insertions(+), 162 deletions(-) create mode 100644 html/extra-networks-pane.html create mode 100644 html/extra-networks-tree-directory.html create mode 100644 html/extra-networks-tree-file.html diff --git a/html/extra-networks-card-minimal.html b/html/extra-networks-card-minimal.html index a6a54d9f4ac..d66df7dfb3f 100644 --- a/html/extra-networks-card-minimal.html +++ b/html/extra-networks-card-minimal.html @@ -1,3 +1,4 @@
    - {name}{copy_path_button}{metadata_button}{edit_button} + {name} + {copy_path_button}{metadata_button}{edit_button}
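The new pane template added below stitches three independently generated
fragments together; create_html() later in this patch does roughly the
following (a sketch with keyword names matching the placeholders in
extra-networks-pane.html, written as keyword arguments rather than the
patch's **{...} dict for readability):

    pane = self.extra_networks_pane_template.format(
        tabname=tabname,
        network_type_id=self.name.replace(" ", "_"),
        tree_html=self.create_tree_view_html(tabname),
        subdirs_html=self.create_subdirs_html(tabname),
        items_html=self.create_card_view_html(tabname),
    )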
    diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html new file mode 100644 index 00000000000..93bad698ddb --- /dev/null +++ b/html/extra-networks-pane.html @@ -0,0 +1,11 @@ +
    + {subdirs_html} +
    +
    +
    + {tree_html} +
    +
    + {items_html} +
    +
    \ No newline at end of file diff --git a/html/extra-networks-tree-directory.html b/html/extra-networks-tree-directory.html new file mode 100644 index 00000000000..cec15588635 --- /dev/null +++ b/html/extra-networks-tree-directory.html @@ -0,0 +1,4 @@ +
    +{folder_name} +{content} +
    \ No newline at end of file diff --git a/html/extra-networks-tree-file.html b/html/extra-networks-tree-file.html new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/html/extra-networks-tree-file.html @@ -0,0 +1 @@ + diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 40309d5575c..33f45c8e4d0 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -253,6 +253,28 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } +function extraNetworksFolderClick(event, tabs_id) { + var els = document.querySelectorAll(".folder-item-summary.selected"); + [...els].forEach(el => { + el.classList.remove("selected"); + }); + event.target.classList.add("selected"); + + var searchTextArea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); + var text = event.target.classList.contains("search-all") ? "" : event.target.firstChild.textContent.trim(); + searchTextArea.value = text; + updateInput(searchTextArea); + + if (event.target.parentElement.open) { + // before close + console.log("closed"); + } else { + // before open + console.log("opened"); + //console.log("Opened:", event.target.parentElement); + } +} + function extraNetworksSearchButton(tabs_id, event) { var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); var button = event.target; diff --git a/modules/shared_options.py b/modules/shared_options.py index e698c26498c..d2e86ff10b3 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -238,7 +238,6 @@ "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), - "extra_networks_tree_view": OptionInfo(False, "Show extra networks using a directory tree view.").needs_reload_ui(), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index ab484a5dc0b..6318594fa5b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -73,10 +73,11 @@ def _get_tree(_paths: list[str]): # the value can be an empty dict if the directory is empty. We want these # placeholders for empty dirs so we can inform the user later. for path in paths: + short_path = os.path.basename(path) # Wrap the path in a list since that is what the `_get_tree` expects. 
- res[path] = _get_tree([path]) - if res[path]: - res[path] = res[path][os.path.basename(path)] + res[short_path] = _get_tree([path]) + if res[short_path]: + res[short_path] = res[short_path][os.path.basename(path)] return res @@ -153,10 +154,9 @@ def __init__(self, title): self.title = title self.name = title.lower() self.id_page = self.name.replace(" ", "_") - if shared.opts.extra_networks_tree_view: - self.card_page = shared.html("extra-networks-card-minimal.html") - else: - self.card_page = shared.html("extra-networks-card.html") + self.extra_networks_pane_template = shared.html("extra-networks-pane.html") + self.card_page_template = shared.html("extra-networks-card.html") + self.card_page_minimal_template = shared.html("extra-networks-card-minimal.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -182,15 +182,14 @@ def link_preview(self, filename): def search_terms_from_path(self, filename, possible_directories=None): abspath = os.path.abspath(filename) - for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): - parentdir = os.path.abspath(parentdir) + parentdir = os.path.dirname(os.path.abspath(parentdir)) if abspath.startswith(parentdir): - return abspath[len(parentdir):].replace('\\', '/') + return os.path.relpath(abspath, parentdir) return "" - def create_item_html(self, tabname: str, item: dict) -> str: + def create_item_html(self, tabname: str, item: dict, template: Optional[str] = None) -> str: """Generates HTML for a single ExtraNetworks Item Args: @@ -265,7 +264,10 @@ def create_item_html(self, tabname: str, item: dict) -> str: "tabname": quote_js(tabname), } - return self.card_page.format(**args) + if template: + return template.format(**args) + else: + return self.card_page.format(**args) def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. @@ -276,53 +278,67 @@ def create_tree_view_html(self, tabname: str) -> str: Returns: HTML string generated for this tree view. """ - self_name_id = self.name.replace(" ", "_") - res = f"
    " - - self.metadata = {} - self.items = {x["name"]: x for x in self.list_items()} + res = f"" + # Generate HTML for the tree. roots = self.allowed_directories_for_previews() tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()} tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items) if not tree: - return res + "
    " + return res - file_template = "
<li class='file-item'>{}</li>
  • " + file_template = "
<li class='file-item'>{card}</li>
  • " dir_template = ( - "
    " - "{}" - "{}
" + "
" + "" + "{folder_name}" + "" + "
    {content}
" "
" ) def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: """Recursively builds HTML for a tree.""" - _res = "
    " + _res = "" if not data: - return "
<li>DIRECTORY IS EMPTY</li>
    " + return "
<li>DIRECTORY IS EMPTY</li>
  • " for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): if isinstance(v, (ExtraNetworksItem,)): - _res += file_template.format(self.create_item_html(tabname, v.item)) + item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) + _res += file_template.format(**{"card": item_html}) else: - _res += dir_template.format("", k, _build_tree(v)) + tmp = dir_template.format( + **{ + "attributes": "", + "tabname": tabname, + "folder_name": k, + "content": _build_tree(v), + } + ) + _res += tmp + return _res - res += "
      " # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): # If root is empty, append the "disabled" attribute to the template details tag. - res += dir_template.format("open" if v else "open disabled", k, _build_tree(v)) + res += "
        " + res += dir_template.format( + **{ + "attributes": "open" if v else "open", + "tabname": tabname, + "folder_name": k, + "content": _build_tree(v), + } + ) + res += "
      " res += "
    " - res += "" return res - def create_card_view_html(self, tabname): - items_html = "" - self.metadata = {} + def create_subdirs_html(self, tabname): subdirs = {} for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: @@ -355,43 +371,53 @@ def create_card_view_html(self, tabname): subdirs = {"": 1, **subdirs} subdirs_html_template = ( - "" ) - subdirs_html = "".join( + return "".join( [ subdirs_html_template.format( - " search-all" if subdir == "" else "", - tabname, - html.escape(subdir if subdir != "" else "all"), + **{ + "classes": "search-all" if not subdir else "", + "tabname": tabname, + "content": html.escape(subdir if subdir else "all"), + } ) for subdir in subdirs ] ) + def create_card_view_html(self, tabname): + res = "" self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): - items_html += self.create_item_html(tabname, item) + res += self.create_item_html(tabname, item, self.card_page_template) - if items_html == "": + if res == "": dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) - items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) - - self_name_id = self.name.replace(" ", "_") - - res = ( - f"
    {subdirs_html}
    " - f"
    {items_html}
    " - ) + res = shared.html("extra-networks-no-cards.html").format(dirs=dirs) return res def create_html(self, tabname): - if shared.opts.extra_networks_tree_view: - return self.create_tree_view_html(tabname) - else: - return self.create_card_view_html(tabname) + self.metadata = {} + self.items = {x["name"]: x for x in self.list_items()} + + tree_view_html = self.create_tree_view_html(tabname) + subdirs_html = self.create_subdirs_html(tabname) + card_view_html = self.create_card_view_html(tabname) + network_type_id = self.name.replace(" ", "_") + + return self.extra_networks_pane_template.format( + **{ + "tabname": tabname, + "network_type_id": network_type_id, + "tree_html": tree_view_html, + "subdirs_html": subdirs_html, + "items_html": card_view_html, + } + ) def create_item(self, name, index=None): raise NotImplementedError() @@ -516,19 +542,19 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - tab_controls = {} - edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) - tab_controls["edit_search"] = edit_search - tab_controls["dropdown_sort"] = dropdown_sort - tab_controls["button_sortorder"] = button_sortorder - tab_controls["button_refresh"] = button_refresh - tab_controls["checkbox_show_dirs"] = checkbox_show_dirs + tab_controls = [ + edit_search, + dropdown_sort, + button_sortorder, + button_refresh, + checkbox_show_dirs, + ] ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -538,32 +564,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js="function(){ extraNetworksUrelatedTabSelected('" + tabname + "'); }", inputs=[], - outputs=list(tab_controls.values()), + outputs=tab_controls, show_progress=False, ) - visible_controls = list(tab_controls.keys()) - if shared.opts.extra_networks_tree_view: - visible_controls = ["button_refresh"] - for page, tab in zip(ui.stored_extra_pages, related_tabs): allow_prompt = "true" if page.allow_prompt else "false" allow_negative_prompt = "true" if page.allow_negative_prompt else "false" jscode = ( "extraNetworksTabSelected(" - f"'{tabname}', " - f"'{tabname}_{page.id_page}_prompts', " - f"'{allow_prompt}', " - f"'{allow_negative_prompt}'" + f"'{tabname}', " + f"'{tabname}_{page.id_page}_prompts', " + f"'{allow_prompt}', " + f"'{allow_negative_prompt}'" ");" ) tab.select( - fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], + fn=lambda: [gr.update(visible=True) for _ in tab_controls], 
_js="function(){ " + jscode + " }", inputs=[], - outputs=list(tab_controls.values()), + outputs=tab_controls, show_progress=False, ) diff --git a/style.css b/style.css index 680a5f837b7..4f285c68b06 100644 --- a/style.css +++ b/style.css @@ -863,7 +863,7 @@ footer { margin-bottom: 1em; } -.extra-network-cards{ +.extra-network-pane{ height: calc(100vh - 24rem); overflow: clip scroll; resize: vertical; @@ -908,53 +908,75 @@ footer { width: auto; } -.extra-network-cards .nocards{ +.extra-network-pane .nocards{ margin: 1.25em 0.5em 0.5em 0.5em; } -.extra-network-cards .nocards h1{ +.extra-network-pane .nocards h1{ font-size: 1.5em; margin-bottom: 1em; } -.extra-network-cards .nocards li{ +.extra-network-pane .nocards li{ margin-left: 0.5em; } +.extra-network-pane :is(.card, .card-minimal) .button-row{ + display: inline-flex; + visibility: hidden; + color: white; +} -.extra-network-cards .card .button-row{ - display: none; +.extra-network-pane .card .button-row { position: absolute; - color: white; right: 0; - z-index: 1 + z-index: 1; } -.extra-network-cards .card:hover .button-row{ - display: flex; + +.extra-network-pane .card-minimal .button-row { + padding-left: 0.5rem; + padding-right: 0.5rem; + align-items: center; } -.extra-network-cards .card .card-button{ +.extra-network-pane :is(.card:hover, .card-minimal:hover) .button-row{ + visibility: visible; +} + +.extra-network-pane .card-button{ color: white; } -.extra-network-cards .card .metadata-button:before{ +.extra-network-pane .copy-path-button:before { + content: "⎘"; +} + +.extra-network-pane .metadata-button:before{ content: "🛈"; } -.extra-network-cards .card .edit-button:before{ +.extra-network-pane .edit-button:before{ content: "🛠"; } -.extra-network-cards .card .card-button { +.extra-network-pane .card-button { + width: 1.5em; text-shadow: 2px 2px 3px black; + color: white; padding: 0.25em 0.1em; - font-size: 200%; - width: 1.5em; } -.extra-network-cards .card .card-button:hover{ + +.extra-network-pane .card-button:hover{ color: red; } +.extra-network-pane .card .card-button { + font-size: 2rem; +} + +.extra-network-pane .card-minimal .card-button { + font-size: 1rem; +} .standalone-card-preview.card .preview{ position: absolute; @@ -963,7 +985,7 @@ footer { height:100%; } -.extra-network-cards .card, .standalone-card-preview.card{ +.extra-network-pane .card, .standalone-card-preview.card{ display: inline-block; margin: 0.5rem; width: 16rem; @@ -980,15 +1002,15 @@ footer { background-image: url('./file=html/card-no-preview.png') } -.extra-network-cards .card:hover{ +.extra-network-pane .card:hover{ box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35); } -.extra-network-cards .card .actions .additional{ +.extra-network-pane .card .actions .additional{ display: none; } -.extra-network-cards .card .actions{ +.extra-network-pane .card .actions{ position: absolute; bottom: 0; left: 0; @@ -999,45 +1021,45 @@ footer { text-shadow: 0 0 0.2em black; } -.extra-network-cards .card .actions *{ +.extra-network-pane .card .actions *{ color: white; } -.extra-network-cards .card .actions .name{ +.extra-network-pane .card .actions .name{ font-size: 1.7em; font-weight: bold; line-break: anywhere; } -.extra-network-cards .card .actions .description { +.extra-network-pane .card .actions .description { display: block; max-height: 3em; white-space: pre-wrap; line-height: 1.1; } -.extra-network-cards .card .actions .description:hover { +.extra-network-pane .card .actions .description:hover { max-height: none; } -.extra-network-cards .card .actions:hover 
.additional{ +.extra-network-pane .card .actions:hover .additional{ display: block; } -.extra-network-cards .card ul{ +.extra-network-pane .card ul{ margin: 0.25em 0 0.75em 0.25em; cursor: unset; } -.extra-network-cards .card ul a{ +.extra-network-pane .card ul a{ cursor: pointer; } -.extra-network-cards .card ul a:hover{ +.extra-network-pane .card ul a:hover{ color: red; } -.extra-network-cards .card .preview{ +.extra-network-pane .card .preview{ position: absolute; object-fit: cover; width: 100%; @@ -1158,94 +1180,98 @@ body.resizing .resize-handle { border-left: 1px dashed var(--border-color-primary); } -.extra-network-cards .card .copy-path-button:before { - content: "⎘"; -} - -.extra-network-cards .card-minimal .button-column { +.extra-network-pane .card-minimal { display: inline-flex; - visibility: hidden; - color: white; - padding-left: 0.5rem; - padding-right: 0.5rem; - align-items: center; + flex-grow: 1; + position: relative; + overflow: hidden; + cursor: pointer; + font-size: 1rem; + font-weight: bold; + line-break: anywhere; } -.extra-network-cards .card-minimal:hover .button-column { - visibility: visible; +/* Pushes buttons to right */ +.extra-network-pane .card-minimal .name { + flex-grow: 1; } -.extra-network-cards .card-minimal .copy-path-button:before { - content: "⎘"; +.file-item, +.folder-item, +.folder-item-summary { + padding-left: 0.05rem; + cursor: pointer; + user-select: none; + font-size: 1rem; } -.extra-network-cards .card-minimal .metadata-button:before{ - content: "🛈"; +.extra-network-pane .extra-network-tree .folder-item-summary:hover, +.extra-network-pane .extra-network-tree .file-item:hover { + -webkit-transition: all 0.1s ease-in-out; + transition: all 0.1s ease-in-out; + background-color: var(--neutral-200); } -.extra-network-cards .card-minimal .edit-button:before{ - content: "🛠"; +.dark .extra-network-pane .extra-network-tree .folder-item-summary:hover, +.dark .extra-network-pane .extra-network-tree .file-item:hover { + -webkit-transition: all 0.05s ease-in-out; + transition: all 0.05s ease-in-out; + background-color: var(--neutral-800); } -.extra-network-cards .card-minimal .card-button { - color: white; - text-shadow: 2px 2px 3px black; - font-size: 1rem; - width: 1.5rem; +/* prevents clicking/collapsing of details tags when disabled attribute is used*/ +.extra-network-pane .extra-network-tree details[disabled] summary { + pointer-events: none; + user-select: none; } -.extra-network-cards .card-minimal .card-button:hover { - color: red; +.extra-network-pane .extra-network-tree details.folder-item > summary { + list-style-type: '📁'; + text-overflow: ellipsis; } -.extra-network-cards .card-minimal { - display: inline-flex; - position: relative; - overflow: hidden; - cursor: pointer; - font-size: 1rem; - font-weight: bold; - line-break: anywhere; +.extra-network-pane .extra-network-tree details.folder-item[open] > summary { + list-style-type: '📂'; + text-overflow: ellipsis; } -.file-item { - list-style-type: '📄'; +.extra-network-pane .extra-network tree ul.folder-container { + list-style: none; + font-size: 1rem; + text-overflow: ellipsis; } -/* prevents clicking/collapsing of details tags when disabled attribute is used*/ -details[disabled] summary { - pointer-events: none; - user-select: none; +.extra-network-pane .extra-network-tree li.file-item { + display: flex; + position: relative; + align-items: center; } -details.folder-item > summary { - list-style-type: '📁'; +.extra-network-pane .extra-network-tree li.file-item::before { + content: '📄'; + 
font-size: 0.85rem; + vertical-align: middle; } -details.folder-item[open] > summary { - list-style-type: '📂'; +.extra-network-pane { + display: flex; } -.file-item, -.folder-item, -.folder-item-summary { +.extra-network-pane .extra-network-subdirs { display: block; +} +.extra-network-pane .extra-network-tree { font-size: 1rem; - padding: 0.05rem; - cursor: pointer; - user-select: none; + width: 25%; } - -.folder-item-summary:hover, -.file-item:hover { - -webkit-transition: all 0.1s ease-in-out; - transition: all 0.1s ease-in-out; - background-color: var(--neutral-200); +.extra-network-pane .extra-network-cards { + flex-grow: 1; } -.dark .folder-item-summary:hover, -.dark .file-item:hover { - -webkit-transition: all 0.05s ease-in-out; - transition: all 0.05s ease-in-out; +.dark .extra-network-tree .folder-item-summary.selected{ background-color: var(--neutral-800); } + +.extra-network-tree .folder-item-summary.selected { + background-color: var(--neutral-200); +} From 0011640ab187e5a59098bb80d165a23f9d08568c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 Jan 2024 08:29:42 +0200 Subject: [PATCH 0575/1174] Logging: set formatter correctly for fallback logger too --- modules/logging_config.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/logging_config.py b/modules/logging_config.py index 11eee9a63db..8e31d8c9fd1 100644 --- a/modules/logging_config.py +++ b/modules/logging_config.py @@ -36,20 +36,21 @@ def setup_logging(loglevel): # Already configured, do not interfere return + formatter = logging.Formatter( + '%(asctime)s %(levelname)s [%(name)s] %(message)s', + '%Y-%m-%d %H:%M:%S', + ) + if os.environ.get("SD_WEBUI_RICH_LOG"): from rich.logging import RichHandler handler = RichHandler() else: handler = logging.StreamHandler() + handler.setFormatter(formatter) if TqdmLoggingHandler: handler = TqdmLoggingHandler(handler) - formatter = logging.Formatter( - '%(asctime)s %(levelname)s [%(name)s] %(message)s', - '%Y-%m-%d %H:%M:%S', - ) - handler.setFormatter(formatter) log_level = getattr(logging, loglevel.upper(), None) or logging.INFO From 0726a6e12e85a37d1e514f5603acf9f058c11783 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Thu, 11 Jan 2024 15:06:57 -0500 Subject: [PATCH 0576/1174] Finish base layout. Fix bugs. Need to test for stability and clean up. 
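The logging fix a few lines up works by attaching the formatter to the stream handler before the handler is optionally wrapped in TqdmLoggingHandler, so the plain fallback path comes out formatted too. A condensed sketch of that ordering (hypothetical `make_handler` name; only the standard `logging` module assumed):

    import logging

    def make_handler(loglevel="INFO"):
        formatter = logging.Formatter(
            '%(asctime)s %(levelname)s [%(name)s] %(message)s',
            '%Y-%m-%d %H:%M:%S',
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)  # attach before any wrapping
        handler.setLevel(getattr(logging, loglevel.upper(), None) or logging.INFO)
        return handler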
--- .../Lora/ui_extra_networks_lora.py | 5 +- html/extra-networks-card.html | 4 +- html/extra-networks-pane.html | 3 - javascript/extraNetworks.js | 48 ++++---- modules/ui_extra_networks.py | 113 ++++++------------ modules/ui_extra_networks_checkpoints.py | 5 +- modules/ui_extra_networks_hypernets.py | 6 +- .../ui_extra_networks_textual_inversion.py | 5 +- style.css | 24 ++-- 9 files changed, 89 insertions(+), 124 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index df02c663b12..db612fa2ba3 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -24,13 +24,16 @@ def create_item(self, name, index=None, enable_filter=True): alias = lora_on_disk.get_alias() + search_terms = [self.search_terms_from_path(lora_on_disk.filename)] + if lora_on_disk.hash: + search_terms.append(lora_on_disk.hash) item = { "name": name, "filename": lora_on_disk.filename, "shorthash": lora_on_disk.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""), + "search_terms": search_terms, "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 7770094da14..d163fe370a5 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -6,9 +6,7 @@ {edit_button}
    -
    - -
    +
    {search_terms}
    {name} {description}
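Across this patch, the single `search_term` string on each item becomes a `search_terms` list (the relative path, plus the model hash when one is available), and the card template renders every term in its own hidden span so the frontend can join them while filtering. A rough sketch of the convention, with hypothetical helper names:

    def make_search_terms(relpath, file_hash=None):
        # Path first, then the hash if the model has one.
        terms = [relpath]
        if file_hash:
            terms.append(file_hash)
        return terms

    def item_matches(search_terms, query):
        # Mirrors the updated applyFilter: join all terms and do a
        # case-insensitive substring match.
        return query.lower() in " ".join(search_terms).lower()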
    diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 93bad698ddb..20cf6686755 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,6 +1,3 @@ -
    - {subdirs_html} -
    {tree_html} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 33f45c8e4d0..4cc67fd18a3 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -24,8 +24,6 @@ function setupExtraNetworksForTab(tabname) { var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); - var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); @@ -33,14 +31,14 @@ function setupExtraNetworksForTab(tabname) { tabs.appendChild(sort); tabs.appendChild(sortOrder); tabs.appendChild(refresh); - tabs.appendChild(showDirsDiv); var applyFilter = function() { var searchTerm = search.value.toLowerCase(); gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); - var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase(); + + var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase() }).join(" "); var visible = text.indexOf(searchTerm) != -1; @@ -100,15 +98,6 @@ function setupExtraNetworksForTab(tabname) { extraNetworksApplySort[tabname] = applySort; extraNetworksApplyFilter[tabname] = applyFilter; - - var showDirsUpdate = function() { - var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }'; - toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked); - localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0); - }; - showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1; - showDirs.addEventListener("change", showDirsUpdate); - showDirsUpdate(); } function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) { @@ -136,14 +125,23 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp } } +function clearSearch(tabname) { + // Clear search box. + var tab_id = tabname + "_extra_search"; + var searchTextarea = gradioApp().querySelector("#" + tab_id + ' > label > textarea'); + searchTextarea.value = ""; + updateInput(searchTextarea); +} + -function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) +function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) extraNetworksMovePromptToTab(tabname, '', false, false); + clearSearch(tabname); } function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); - + clearSearch(tabname); } function applyExtraNetworkFilter(tabname) { @@ -254,6 +252,15 @@ function saveCardPreview(event, tabname, filename) { } function extraNetworksFolderClick(event, tabs_id) { + // If folder is open but not selected, we don't want to collapse it. Instead + // we override the removal of the "open" attribute so that the folder is + // only selected but remains open. 
Since this is a toggle event, removing + // the "open" attribute instead forces the event to add it back which keeps it open. + if (event.target.parentElement.open && !event.target.classList.contains("selected")) { + // before event handler removes "open" + event.target.parentElement.removeAttribute("open"); + } + var els = document.querySelectorAll(".folder-item-summary.selected"); [...els].forEach(el => { el.classList.remove("selected"); @@ -261,18 +268,9 @@ function extraNetworksFolderClick(event, tabs_id) { event.target.classList.add("selected"); var searchTextArea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); - var text = event.target.classList.contains("search-all") ? "" : event.target.firstChild.textContent.trim(); + var text = event.target.classList.contains("search-all") ? "" : event.target.getAttribute("data-path"); searchTextArea.value = text; updateInput(searchTextArea); - - if (event.target.parentElement.open) { - // before close - console.log("closed"); - } else { - // before open - console.log("opened"); - //console.log("Opened:", event.target.parentElement); - } } function extraNetworksSearchButton(tabs_id, event) { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 6318594fa5b..2e226ba0029 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -48,23 +48,24 @@ def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) if isinstance(paths, (str,)): paths = [paths] - def _get_tree(_paths: list[str]): + def _get_tree(_paths: list[str], _root: str): _res = {} for path in _paths: + relpath = os.path.relpath(path, _root) if os.path.isdir(path): dir_items = os.listdir(path) # Ignore empty directories. if not dir_items: continue - dir_tree = _get_tree([os.path.join(path, x) for x in dir_items]) + dir_tree = _get_tree([os.path.join(path, x) for x in dir_items], _root) # We only want to store non-empty folders in the tree. if dir_tree: - _res[os.path.basename(path)] = dir_tree + _res[relpath] = dir_tree else: if path not in items: continue # Add the ExtraNetworksItem to the result. - _res[os.path.basename(path)] = items[path] + _res[relpath] = items[path] return _res res = {} @@ -73,11 +74,13 @@ def _get_tree(_paths: list[str]): # the value can be an empty dict if the directory is empty. We want these # placeholders for empty dirs so we can inform the user later. for path in paths: - short_path = os.path.basename(path) + root = os.path.dirname(path) + relpath = os.path.relpath(path, root) # Wrap the path in a list since that is what the `_get_tree` expects. - res[short_path] = _get_tree([path]) - if res[short_path]: - res[short_path] = res[short_path][os.path.basename(path)] + res[relpath] = _get_tree([path], root) + if res[relpath]: + # We need to pull the inner path out one for these root dirs. + res[relpath] = res[relpath][relpath] return res @@ -245,6 +248,17 @@ def create_item_html(self, tabname: str, item: dict, template: Optional[str] = N sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() + search_terms_html = "" + search_term_template = "{search_term}" + for search_term in item.get("search_terms", []): + search_terms_html += search_term_template.format( + **{ + "style": "display: none;", + "class": "search_terms" + (" search_only" if search_only else ""), + "search_term": search_term, + } + ) + # Some items here might not be used depending on HTML template used. 
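# The code below builds an `args` dict and renders it through an optional
# template, falling back to the page's default card template. A condensed
# sketch of that pattern (hypothetical names; the default template here is
# invented for illustration):
def render_item(args, template=None, default_template="{name}: {description}"):
    return (template or default_template).format(**args)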
args = { "background_image": background_image, @@ -258,7 +272,7 @@ def create_item_html(self, tabname: str, item: dict, template: Optional[str] = N "prompt": item.get("prompt", None), "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_only": " search_only" if search_only else "", - "search_term": item.get("search_term", ""), + "search_terms": search_terms_html, "sort_keys": sort_keys, "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", "tabname": quote_js(tabname), @@ -278,7 +292,7 @@ def create_tree_view_html(self, tabname: str) -> str: Returns: HTML string generated for this tree view. """ - res = f"" + res = "" # Generate HTML for the tree. roots = self.allowed_directories_for_previews() @@ -291,7 +305,8 @@ def create_tree_view_html(self, tabname: str) -> str: file_template = "
  • {card}
  • " dir_template = ( "
    " - "" + "" "{folder_name}" "" "
      {content}
    " @@ -309,16 +324,15 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) _res += file_template.format(**{"card": item_html}) else: - tmp = dir_template.format( + _res += dir_template.format( **{ "attributes": "", "tabname": tabname, - "folder_name": k, + "folder_name": os.path.basename(k), + "data_path": k, "content": _build_tree(v), } ) - _res += tmp - return _res # Add each root directory to the tree. @@ -329,65 +343,15 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: **{ "attributes": "open" if v else "open", "tabname": tabname, - "folder_name": k, + "folder_name": os.path.basename(k), + "data_path": k, "content": _build_tree(v), } ) res += "
" res += "" - return res - def create_subdirs_html(self, tabname): - subdirs = {} - - for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: - for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): - for dirname in sorted(dirs, key=shared.natural_sort_key): - x = os.path.join(root, dirname) - - if not os.path.isdir(x): - continue - - subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") - - if shared.opts.extra_networks_dir_button_function: - if not subdir.startswith("/"): - subdir = "/" + subdir - else: - while subdir.startswith("/"): - subdir = subdir[1:] - - is_empty = len(os.listdir(x)) == 0 - if not is_empty and not subdir.endswith("/"): - subdir = subdir + "/" - - if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories: - continue - - subdirs[subdir] = 1 - - if subdirs: - subdirs = {"": 1, **subdirs} - - subdirs_html_template = ( - "" - ) - return "".join( - [ - subdirs_html_template.format( - **{ - "classes": "search-all" if not subdir else "", - "tabname": tabname, - "content": html.escape(subdir if subdir else "all"), - } - ) for subdir in subdirs - ] - ) - def create_card_view_html(self, tabname): res = "" self.items = {x["name"]: x for x in self.list_items()} @@ -405,7 +369,6 @@ def create_html(self, tabname): self.items = {x["name"]: x for x in self.list_items()} tree_view_html = self.create_tree_view_html(tabname) - subdirs_html = self.create_subdirs_html(tabname) card_view_html = self.create_card_view_html(tabname) network_type_id = self.name.replace(" ", "_") @@ -414,7 +377,6 @@ def create_html(self, tabname): "tabname": tabname, "network_type_id": network_type_id, "tree_html": tree_view_html, - "subdirs_html": subdirs_html, "items_html": card_view_html, } ) @@ -534,7 +496,12 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) + page_elem.change( + fn=lambda: None, + _js=f"function(){{applyExtraNetworkFilter({tabname}_extra_search); return []}}", + inputs=[], + outputs=[], + ) editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() @@ -542,18 +509,16 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) + edit_search = gr.Textbox('', show_label=False, elem_id=f"{tabname}_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) - checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', 
elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) tab_controls = [ edit_search, dropdown_sort, button_sortorder, button_refresh, - checkbox_show_dirs, ] ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) @@ -562,7 +527,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): for tab in unrelated_tabs: tab.select( fn=lambda: [gr.update(visible=False) for _ in tab_controls], - _js="function(){ extraNetworksUrelatedTabSelected('" + tabname + "'); }", + _js=f"function(){{ extraNetworksUnrelatedTabSelected('{tabname}'); }}", inputs=[], outputs=tab_controls, show_progress=False, diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 1693e71f16f..e7976ba1260 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -21,13 +21,16 @@ def create_item(self, name, index=None, enable_filter=True): return path, ext = os.path.splitext(checkpoint.filename) + search_terms = [self.search_terms_from_path(checkpoint.filename)] + if checkpoint.sha256: + search_terms.append(checkpoint.sha256) return { "name": checkpoint.name_for_extra, "filename": checkpoint.filename, "shorthash": checkpoint.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), + "search_terms": search_terms, "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": checkpoint.metadata, diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index c96c4fa3b12..2fb4bd190a1 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -20,14 +20,16 @@ def create_item(self, name, index=None, enable_filter=True): path, ext = os.path.splitext(full_path) sha256 = sha256_from_cache(full_path, f'hypernet/{name}') shorthash = sha256[0:10] if sha256 else None - + search_terms = [self.search_terms_from_path(path)] + if sha256: + search_terms.append(sha256) return { "name": name, "filename": full_path, "shorthash": shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""), + "search_terms": search_terms, "prompt": quote_js(f""), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 1b334fda174..deb7cb8733b 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -18,13 +18,16 @@ def create_item(self, name, index=None, enable_filter=True): return path, ext = os.path.splitext(embedding.filename) + search_terms = [self.search_terms_from_path(embedding.filename)] + if embedding.hash: + search_terms.append(embedding.hash) return { "name": name, "filename": embedding.filename, "shorthash": embedding.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""), + "search_terms": search_terms, "prompt": quote_js(embedding.name), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": 
{'default': index, **self.get_sort_keys(embedding.filename)}, diff --git a/style.css b/style.css index 4f285c68b06..70d80d6a170 100644 --- a/style.css +++ b/style.css @@ -878,16 +878,8 @@ footer { margin: 0.3em; } -.extra-network-subdirs{ - padding: 0.2em 0.35em; -} - -.extra-network-subdirs button{ - margin: 0 0.15em; -} .extra-networks .tab-nav .search, -.extra-networks .tab-nav .sort, -.extra-networks .tab-nav .show-dirs +.extra-networks .tab-nav .sort { margin: 0.3em; align-self: center; @@ -1196,6 +1188,10 @@ body.resizing .resize-handle { flex-grow: 1; } +.folder-container { + margin-left: 1.5em !important; +} + .file-item, .folder-item, .folder-item-summary { @@ -1235,7 +1231,7 @@ body.resizing .resize-handle { text-overflow: ellipsis; } -.extra-network-pane .extra-network tree ul.folder-container { +.extra-network-pane .extra-network-tree ul.folder-container { list-style: none; font-size: 1rem; text-overflow: ellipsis; @@ -1257,15 +1253,15 @@ body.resizing .resize-handle { display: flex; } -.extra-network-pane .extra-network-subdirs { - display: block; -} .extra-network-pane .extra-network-tree { font-size: 1rem; - width: 25%; + min-width: 25%; + max-width: 25%; + border: 1px solid var(--block-border-color); } .extra-network-pane .extra-network-cards { flex-grow: 1; + border: 1px solid var(--block-border-color); } .dark .extra-network-tree .folder-item-summary.selected{ From 541881e3188b4f59d82043c5fdfac02607ec3119 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:11:06 -0800 Subject: [PATCH 0577/1174] Adjust brush size with hotkeys. --- extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index 45c7600ac5f..c695debc65b 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -218,6 +218,8 @@ onUiLoaded(async() => { canvas_hotkey_fullscreen: "KeyS", canvas_hotkey_move: "KeyF", canvas_hotkey_overlap: "KeyO", + canvas_hotkey_shrink_brush: "BracketLeft", + canvas_hotkey_grow_brush: "BracketRight", canvas_disabled_functions: [], canvas_show_tooltip: true, canvas_auto_expand: true, @@ -227,6 +229,8 @@ onUiLoaded(async() => { const functionMap = { "Zoom": "canvas_hotkey_zoom", "Adjust brush size": "canvas_hotkey_adjust", + "Shrink brush size": "canvas_hotkey_shrink_brush", + "Grow brush size": "canvas_hotkey_grow_brush", "Moving canvas": "canvas_hotkey_move", "Fullscreen": "canvas_hotkey_fullscreen", "Reset Zoom": "canvas_hotkey_reset", @@ -686,7 +690,9 @@ onUiLoaded(async() => { const hotkeyActions = { [hotkeysConfig.canvas_hotkey_reset]: resetZoom, [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, - [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen + [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen, + [defaultHotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10), + [defaultHotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10) }; const action = hotkeyActions[event.code]; From 47b52d9b28aaabb603358fae5d9b824c49aa627b Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:31:26 -0800 Subject: [PATCH 0578/1174] Add # to the invalid_filename_chars list --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/modules/images.py b/modules/images.py index 87a7bf22139..b6f2358c3ee 100644 --- a/modules/images.py +++ b/modules/images.py @@ -321,7 +321,7 @@ def resize(im, w, h): return res -invalid_filename_chars = '<>:"/\\|?*\n\r\t' +invalid_filename_chars = '#<>:"/\\|?*\n\r\t' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') From b6dc307c99a25bc3f3f3e96ce640a4710c14e133 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 13 Jan 2024 14:45:15 +0400 Subject: [PATCH 0579/1174] fix_extension_check_for_requirements --- modules/extensions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/extensions.py b/modules/extensions.py index 99e7ee60f64..04bda297e79 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -224,13 +224,16 @@ def list_extensions(): # check for requirements for extension in extensions: + if not extension.enabled: + continue + for req in extension.metadata.requires: required_extension = loaded_extensions.get(req) if required_extension is None: errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False) continue - if not extension.enabled: + if not required_extension.enabled: errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False) continue From 02e6963325e5221e0efb96a63f3dc849550489b7 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sat, 13 Jan 2024 13:16:39 -0500 Subject: [PATCH 0580/1174] continue cleanup and redesign. --- html/extra-networks-tree-button.html | 11 + javascript/extraNetworks.js | 294 ++++++++++++++++++--------- modules/ui_extra_networks.py | 177 +++++++++++++--- style.css | 277 +++++++++++++++++++------ 4 files changed, 572 insertions(+), 187 deletions(-) create mode 100644 html/extra-networks-tree-button.html diff --git a/html/extra-networks-tree-button.html b/html/extra-networks-tree-button.html new file mode 100644 index 00000000000..920330f7a7d --- /dev/null +++ b/html/extra-networks-tree-button.html @@ -0,0 +1,11 @@ + + \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 3e3b03f320e..cce2246899e 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -16,88 +16,110 @@ function toggleCss(key, css, enable) { } function setupExtraNetworksForTab(tabname) { - gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); + var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs'); + this_tab.classList.add('extra-networks'); - var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var searchDiv = gradioApp().getElementById(tabname + '_extra_search'); - var search = searchDiv.querySelector('textarea'); - var sort = gradioApp().getElementById(tabname + '_extra_sort'); - var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); - var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); - var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); - - tabs.appendChild(searchDiv); - tabs.appendChild(sort); - tabs.appendChild(sortOrder); - tabs.appendChild(refresh); - - var applyFilter = function() { - var searchTerm = search.value.toLowerCase(); - - gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { - var searchOnly = 
elem.querySelector('.search_only'); - - var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase() }).join(" "); - - var visible = text.indexOf(searchTerm) != -1; - - if (searchOnly && searchTerm.length < 4) { - visible = false; - } - - elem.style.display = visible ? "" : "none"; - }); - - applySort(); - }; - - var applySort = function() { - var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); - - var reverse = sortOrder.classList.contains("sortReverse"); - var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; - sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); - var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; + function registerPrompt(tabname, id) { + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - if (sortKeyStore == sort.dataset.sortkey) { - return; + if (!activePromptTextarea[tabname]) { + activePromptTextarea[tabname] = textarea; } - sort.dataset.sortkey = sortKeyStore; - cards.forEach(function(card) { - card.originalParentElement = card.parentElement; + textarea.addEventListener("focus", function() { + activePromptTextarea[tabname] = textarea; }); - var sortedCards = Array.from(cards); - sortedCards.sort(function(cardA, cardB) { - var a = cardA.dataset[sortKey]; - var b = cardB.dataset[sortKey]; - if (!isNaN(a) && !isNaN(b)) { - return parseInt(a) - parseInt(b); + } + + this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) { + var tab_id = elem.getAttribute("id"); + + var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); + var searchDiv = gradioApp().QuerySelector("#" + tab_id + "_extra_search"); + console.log("HERE:", tab_id + "_extra_search", searchDiv); + var search = searchDiv.value; + var sort = gradioApp().getElementById(tabname + '_extra_sort'); + var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); + var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); + var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); + var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); + tabs.appendChild(searchDiv); + tabs.appendChild(sort); + tabs.appendChild(sortOrder); + tabs.appendChild(refresh); + var applyFilter = function() { + var searchTerm = search.value.toLowerCase(); + + gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { + var searchOnly = elem.querySelector('.search_only'); + + var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase() }).join(" "); + + var visible = text.indexOf(searchTerm) != -1; + + if (searchOnly && searchTerm.length < 4) { + visible = false; + } + + elem.style.display = visible ? "" : "none"; + }); + + applySort(); + }; + + var applySort = function() { + var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); + + var reverse = sortOrder.classList.contains("sortReverse"); + var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); + var sortKeyStore = sortKey + "-" + (reverse ? 
"Descending" : "Ascending") + "-" + cards.length; + + if (sortKeyStore == sort.dataset.sortkey) { + return; } - - return (a < b ? -1 : (a > b ? 1 : 0)); - }); - if (reverse) { - sortedCards.reverse(); - } - cards.forEach(function(card) { - card.remove(); - }); - sortedCards.forEach(function(card) { - card.originalParentElement.appendChild(card); + sort.dataset.sortkey = sortKeyStore; + + cards.forEach(function(card) { + card.originalParentElement = card.parentElement; + }); + var sortedCards = Array.from(cards); + sortedCards.sort(function(cardA, cardB) { + var a = cardA.dataset[sortKey]; + var b = cardB.dataset[sortKey]; + if (!isNaN(a) && !isNaN(b)) { + return parseInt(a) - parseInt(b); + } + + return (a < b ? -1 : (a > b ? 1 : 0)); + }); + if (reverse) { + sortedCards.reverse(); + } + cards.forEach(function(card) { + card.remove(); + }); + sortedCards.forEach(function(card) { + card.originalParentElement.appendChild(card); + }); + }; + + search.addEventListener("input", applyFilter); + sortOrder.addEventListener("click", function() { + sortOrder.classList.toggle("sortReverse"); + applySort(); }); - }; + applyFilter(); + + extraNetworksApplySort[tab_id] = applySort; + extraNetworksApplyFilter[tab_id] = applyFilter; - search.addEventListener("input", applyFilter); - sortOrder.addEventListener("click", function() { - sortOrder.classList.toggle("sortReverse"); - applySort(); + registerPrompt(tab_id, tab_id + "_prompt"); + registerPrompt(tab_id, tab_id + "_neg_prompt"); }); - applyFilter(); - extraNetworksApplySort[tabname] = applySort; - extraNetworksApplyFilter[tabname] = applyFilter; + + } function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) { @@ -136,12 +158,12 @@ function clearSearch(tabname) { function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) extraNetworksMovePromptToTab(tabname, '', false, false); - clearSearch(tabname); + //clearSearch(tabname); } function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); - clearSearch(tabname); + //clearSearch(tabname); } function applyExtraNetworkFilter(tabname) { @@ -159,23 +181,6 @@ var activePromptTextarea = {}; function setupExtraNetworks() { setupExtraNetworksForTab('txt2img'); setupExtraNetworksForTab('img2img'); - - function registerPrompt(tabname, id) { - var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - - if (!activePromptTextarea[tabname]) { - activePromptTextarea[tabname] = textarea; - } - - textarea.addEventListener("focus", function() { - activePromptTextarea[tabname] = textarea; - }); - } - - registerPrompt('txt2img', 'txt2img_prompt'); - registerPrompt('txt2img', 'txt2img_neg_prompt'); - registerPrompt('img2img', 'img2img_prompt'); - registerPrompt('img2img', 'img2img_neg_prompt'); } onUiLoaded(setupExtraNetworks); @@ -262,6 +267,106 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } +function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { + /** + * Processes `onclick` events when user clicks on files in tree. + * + * @param event The generated event. + * @param btn The clicked `action-list-item` button. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. 
+ */
+    var par = btn.parentElement;
+    var search_id = tabname + "_" + tab_id + "_extra_search";
+    var type = par.getAttribute("data-tree-entry-type");
+    var path = par.getAttribute("data-path");
+}
+
+function extraNetworksTreeProcessDirectoryClick(event, btn) {
+    /**
+     * Processes `onclick` events when user clicks on directories in tree.
+     *
+     * Here is how the tree reacts to clicks for various states:
+     * unselected unopened directory: Directory is selected and expanded.
+     * unselected opened directory: Directory is selected.
+     * selected opened directory: Directory is collapsed and deselected.
+     * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged.
+     *
+     * @param event The generated event.
+     * @param btn The clicked `action-list-item` button.
+     */
+    var ul = btn.nextElementSibling;
+    // This is the actual target that the user clicked on within the target button.
+    // We use this to detect if the chevron was clicked.
+    var true_targ = event.target;
+
+    function _expand_or_collapse(_ul, _btn) {
+        // Expands <ul>
    if it is collapsed, collapses otherwise. Updates button attributes. + if (_ul.hasAttribute("data-hidden")) { + _ul.removeAttribute("data-hidden"); + _btn.setAttribute("expanded", "true"); + } else { + _ul.setAttribute("data-hidden", ""); + _btn.setAttribute("expanded", "false"); + } + } + + function _remove_selected_from_all() { + // Removes the `selected` attribute from all buttons. + var sels = document.querySelectorAll("button.action-list-content"); + [...sels].forEach(el => { + el.removeAttribute("selected"); + }) + } + + function _select_button(_btn) { + // Removes `selected` attribute from all buttons then adds to passed button. + _remove_selected_from_all(); + _btn.setAttribute("selected", ""); + } + + // If user clicks on the chevron, then we do not select the folder. + if (true_targ.matches(".action-list-item-action--leading, .action-list-item-action-chevron")) { + _expand_or_collapse(ul, btn); + } else { + // User clicked anywhere else on the button. + if (btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden")) { + // If folder is select and open, collapse and deselect button. + _expand_or_collapse(ul, btn); + btn.removeAttribute("selected"); + } else if (!(!btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden"))) { + // If folder is open and not selected, then we don't collapse; just select. + // NOTE: Double inversion sucks but it is the clearest way to show the branching here. + _expand_or_collapse(ul, btn); + _select_button(btn); + } else { + // All other cases, just select the button. + _select_button(btn); + } + + } +} + +function extraNetworksTreeOnClick(event, tabname, tab_id) { + /** + * Handles `onclick` events for buttons within an `extra-network-tree .action-list--tree`. + * + * Determines whether the clicked button in the tree is for a file entry or a directory + * then calls the appropriate function. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + var btn = event.currentTarget; + var par = btn.parentElement; + if (par.getAttribute("data-tree-entry-type") === "file") { + extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id); + } else { + extraNetworksTreeProcessDirectoryClick(event, btn); + } +} + function extraNetworksFolderClick(event, tabs_id) { // If folder is open but not selected, we don't want to collapse it. Instead // we override the removal of the "open" attribute so that the folder is @@ -434,3 +539,10 @@ window.addEventListener("keydown", function(event) { closePopup(); } }); + +function testprint(e) { + console.log(e); +} + +const testinput = gradioApp().querySelector("#txt2img_lora_extra_search"); +testinput.addEventListener("input", testprint); \ No newline at end of file diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 093ac7b4f0c..9cf5b57fbcb 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -19,6 +19,90 @@ allowed_dirs = set() default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] +tree_tpl = ( + "" + "
      " + "{content}" + "
    " +) + +tree_ul_tpl = ( + "
      " + "{content}" + "
    " +) + +tree_li_dir_tpl = ( + "
  • " + "{content}" + "
  • " +) +tree_li_file_tpl = ( + "
  • " + "{content}" + "
  • " +) + +tree_btn_dir_tpl = ( + "" +) + +tree_btn_file_action_buttons_tpl = ( + "
    " + "
    " + "
    " + "
    " + "
    " + "
    " +) + +tree_btn_file_tpl = ( + "" + "" +) + + @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) @@ -160,6 +244,7 @@ def __init__(self, title): self.extra_networks_pane_template = shared.html("extra-networks-pane.html") self.card_page_template = shared.html("extra-networks-card.html") self.card_page_minimal_template = shared.html("extra-networks-card-minimal.html") + self.tree_button_template = shared.html("extra-networks-tree-button.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -279,7 +364,9 @@ def create_item_html(self, tabname: str, item: dict, template: Optional[str] = N "search_terms": search_terms_html, "sort_keys": sort_keys, "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", - "tabname": quote_js(tabname), + "tabname": tabname, + "tab_id": self.id_page, + } if template: @@ -306,55 +393,81 @@ def create_tree_view_html(self, tabname: str) -> str: if not tree: return res - file_template = "
  • {card}
  • " - dir_template = ( - "
    " - "" - "{folder_name}" - "" - "
      {content}
    " - "
    " - ) - def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: """Recursively builds HTML for a tree.""" _res = "" if not data: - return "
  • DIRECTORY IS EMPTY
  • " + return ( + "
    " + "Directory is empty" + "
    " + ) for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): if isinstance(v, (ExtraNetworksItem,)): - item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) - _res += file_template.format(**{"card": item_html}) + _action_buttons = tree_btn_file_action_buttons_tpl.format( + **{ + "path": quote_js(k), + "filename": quote_js(v.item["name"]), + "tabname": quote_js(tabname), + "tab_id": quote_js(self.id_page), + } + ) + _btn = tree_btn_file_tpl.format( + **{ + "label": v.item["name"], + "filter": v.item["search_terms"], + "tabname": tabname, + "tab_id": self.id_page, + "buttons": _action_buttons, + } + ) + _li = tree_li_file_tpl.format( + **{ + "hash": v.item["shorthash"], + "path": k, + "type": "file", + #"content": _btn, + "content": self.create_item_html(tabname, v.item, self.tree_button_template), + } + ) + _res += _li + #item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) + #_res += file_template.format(**{"card": item_html}) else: - _res += dir_template.format( + _btn = tree_btn_dir_tpl.format( **{ - "attributes": "", + "label": os.path.basename(k), "tabname": tabname, - "folder_name": os.path.basename(k), - "data_path": k, - "content": _build_tree(v), + "tab_id": self.id_page, } ) + _ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) + _li = tree_li_dir_tpl.format(**{"content": _btn + _ul, "path": k}) + _res += _li return _res # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): # If root is empty, append the "disabled" attribute to the template details tag. - res += "
      " - res += dir_template.format( + btn = tree_btn_dir_tpl.format( **{ - "attributes": "open" if v else "open", + "label": os.path.basename(k), "tabname": tabname, - "folder_name": os.path.basename(k), - "data_path": k, - "content": _build_tree(v), + "tab_id": self.id_page, } ) - res += "
    " - res += "
" - return res + ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) + li = tree_li_dir_tpl.format(**{"content": btn + ul, "path": k}) + res += li + + return tree_tpl.format( + **{ + "content": res, + "tabname": tabname, + "tab_id": self.id_page, + } + ) def create_card_view_html(self, tabname): res = "" @@ -375,7 +488,7 @@ def create_html(self, tabname): tree_view_html = self.create_tree_view_html(tabname) card_view_html = self.create_card_view_html(tabname) - network_type_id = self.name.replace(" ", "_") + network_type_id = self.id_page return self.extra_networks_pane_template.format( **{ @@ -506,7 +619,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui.pages.append(page_elem) page_elem.change( fn=lambda: None, - _js=f"function(){{applyExtraNetworkFilter({tabname}_extra_search); return []}}", + _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.id_page}_extra_search); return []}}", inputs=[], outputs=[], ) @@ -517,13 +630,11 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - edit_search = gr.Textbox('', show_label=False, elem_id=f"{tabname}_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) tab_controls = [ - edit_search, dropdown_sort, button_sortorder, button_refresh, diff --git a/style.css b/style.css index aaafaa9d7e1..8aa41088dff 100644 --- a/style.css +++ b/style.css @@ -955,15 +955,15 @@ footer { color: white; } -.extra-network-pane .copy-path-button:before { +.extra-network-pane .copy-path-button::before { content: "⎘"; } -.extra-network-pane .metadata-button:before{ +.extra-network-pane .metadata-button::before{ content: "🛈"; } -.extra-network-pane .edit-button:before{ +.extra-network-pane .edit-button::before{ content: "🛠"; } @@ -1188,102 +1188,253 @@ body.resizing .resize-handle { border-left: 1px dashed var(--border-color-primary); } -.extra-network-pane .card-minimal { - display: inline-flex; - flex-grow: 1; - position: relative; - overflow: hidden; - cursor: pointer; - font-size: 1rem; - font-weight: bold; - line-break: anywhere; +/* ========================= */ +.extra-network-pane { + display: flex; } -/* Pushes buttons to right */ -.extra-network-pane .card-minimal .name { - flex-grow: 1; +.extra-network-pane .extra-network-cards { + display: block; } -.folder-container { - margin-left: 1.5em !important; +.extra-network-pane .extra-network-tree { + display: block; + font-size: 1rem; + min-width: 25%; + border: 1px solid var(--block-border-color); + overflow: hidden; } -.file-item, -.folder-item, -.folder-item-summary { - padding-left: 0.05rem; +.extra-network-tree .action-list--tree { cursor: pointer; + -webkit-user-select: none; + -moz-user-select: none; + -ms-user-select: none; user-select: none; - font-size: 1rem; + margin: 0; + padding: 0; } -.extra-network-pane .extra-network-tree .folder-item-summary:hover, -.extra-network-pane 
.extra-network-tree .file-item:hover { - -webkit-transition: all 0.1s ease-in-out; - transition: all 0.1s ease-in-out; - background-color: var(--neutral-200); +/* Remove auto indentation from tree. Will be overridden later. */ +.extra-network-tree .action-list--subgroup { + margin: 0 !important; + padding: 0 !important; + box-shadow: 0.6rem 0 0 var(--body-background-fill) inset, + 0.8rem 0 0 var(--neutral-800) inset; +} + +/* Set indentation for each depth of tree. */ +.extra-network-tree .action-list--subgroup > .action-list-item { + margin-left: 0.4rem !important; + padding-left: 0.4rem !important; } -.dark .extra-network-pane .extra-network-tree .folder-item-summary:hover, -.dark .extra-network-pane .extra-network-tree .file-item:hover { +/* Styles for tree
<ul> elements. */
+.extra-network-tree .action-list {
+
+}
+
+/* Styles for tree <li> elements. */
+.extra-network-tree .action-list-item {
+    list-style: none;
+    position: relative;
+    background-color: transparent;
+}
+
+/* Directory <ul> */
+.extra-network-tree .action-list-content[expanded=false]+.action-list--subgroup {
+    height: 0;
+    overflow: hidden;
+    visibility: hidden;
+    opacity: 0;
+}
+
+.extra-network-tree .action-list-content[expanded=true]+.action-list--subgroup {
+    height: auto;
+    overflow: visible;
+    visibility: visible;
+    opacity: 1;
+}
+
+/* File <li> */
+.extra-network-tree .action-list-item--subitem {
+}
+
+/* <li> containing <ul>
        */ +.extra-network-tree .action-list-item--has-subitem { +} + +/* BUTTON ELEMENTS */ +/* \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 97a97d61f2c..56c9dc4c35d 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -267,7 +267,7 @@ function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { * Processes `onclick` events when user clicks on files in tree. * * @param event The generated event. - * @param btn The clicked `action-list-item` button. + * @param btn The clicked `tree-list-item` button. * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ @@ -288,7 +288,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged. * * @param event The generated event. - * @param btn The clicked `action-list-item` button. + * @param btn The clicked `tree-list-item` button. * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ @@ -310,7 +310,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function _remove_selected_from_all() { // Removes the `selected` attribute from all buttons. - var sels = document.querySelectorAll("button.action-list-content"); + var sels = document.querySelectorAll("button.tree-list-content"); [...sels].forEach(el => { el.removeAttribute("selected"); }); @@ -331,7 +331,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { // If user clicks on the chevron, then we do not select the folder. - if (true_targ.matches(".action-list-item-action--leading, .action-list-item-action-chevron")) { + if (true_targ.matches(".tree-list-item-action--leading, .tree-list-item-action-chevron")) { _expand_or_collapse(ul, btn); } else { // User clicked anywhere else on the button. @@ -356,7 +356,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function extraNetworksTreeOnClick(event, tabname, tab_id) { /** - * Handles `onclick` events for buttons within an `extra-network-tree .action-list--tree`. + * Handles `onclick` events for buttons within an `extra-network-tree .tree-list--tree`. * * Determines whether the clicked button in the tree is for a file entry or a directory * then calls the appropriate function. diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 9cf5b57fbcb..a49c6c1c54b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -20,28 +20,28 @@ default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] tree_tpl = ( - " \ No newline at end of file diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html new file mode 100644 index 00000000000..4d29b1be273 --- /dev/null +++ b/html/extra-networks-tree.html @@ -0,0 +1,42 @@ +
[markup stripped in extraction -- the new extra-networks-tree.html body defines the tree controls (search input, sort mode / sort direction / refresh buttons wired to the extraNetworksTree*OnClick handlers) and the container that receives {tree}]
        \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 56c9dc4c35d..cf98452a296 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -37,15 +37,20 @@ function setupExtraNetworksForTab(tabname) { return; // `continue` doesn't work in `forEach` loops. This is equivalent. } - var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var sort = gradioApp().getElementById(tabname + '_extra_sort'); - var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); - var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); - var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); - tabs.appendChild(sort); - tabs.appendChild(sortOrder); - tabs.appendChild(refresh); + var sort = gradioApp().querySelector("#" + tab_id + "_extra_sort"); + if (!sort) { + return; // `continue` doesn't work in `forEach` loops. This is equivalent. + } + + var sort_dir = gradioApp().querySelector("#" + tab_id + "_extra_sort_dir"); + if (!sort_dir) { + return; // `continue` doesn't work in `forEach` loops. This is equivalent. + } + + var refresh = gradioApp().querySelector("#" + tab_id + "_extra_refresh"); + if (!refresh) { + return; // `continue` doesn't work in `forEach` loops. This is equivalent. + } var applyFilter = function() { var searchTerm = search.value.toLowerCase(); @@ -72,8 +77,8 @@ function setupExtraNetworksForTab(tabname) { var applySort = function() { var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); - var reverse = sortOrder.classList.contains("sortReverse"); - var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + var reverse = sort_dir.dataset.sortdir == "Descending"; + var sortKey = sort.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; @@ -107,10 +112,7 @@ function setupExtraNetworksForTab(tabname) { }; search.addEventListener("input", applyFilter); - sortOrder.addEventListener("click", function() { - sortOrder.classList.toggle("sortReverse"); - applySort(); - }); + applySort(); applyFilter(); extraNetworksApplySort[tab_id] = applySort; @@ -274,7 +276,7 @@ function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { var par = btn.parentElement; var search_id = tabname + "_" + tab_id + "_extra_search"; var type = par.getAttribute("data-tree-entry-type"); - var path = par.getAttribute("data-path"); + var path = btn.getAttribute("data-path"); } function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { @@ -310,7 +312,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function _remove_selected_from_all() { // Removes the `selected` attribute from all buttons. - var sels = document.querySelectorAll("button.tree-list-content"); + var sels = document.querySelectorAll("div.tree-list-content"); [...sels].forEach(el => { el.removeAttribute("selected"); }); @@ -345,11 +347,11 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { // NOTE: Double inversion sucks but it is the clearest way to show the branching here. 
_expand_or_collapse(ul, btn); _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.parentElement.getAttribute("data-path")); + _update_search(tabname, tab_id, btn.getAttribute("data-path")); } else { // All other cases, just select the button. _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.parentElement.getAttribute("data-path")); + _update_search(tabname, tab_id, btn.getAttribute("data-path")); } } } @@ -374,6 +376,48 @@ function extraNetworksTreeOnClick(event, tabname, tab_id) { } } +function extraNetworksTreeSortOnClick(event, tabname, tab_id) { + var curr_mode = event.currentTarget.dataset.sortmode; + var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + tab_id + "_extra_sort_dir"); + var sort_dir = el_sort_dir.dataset.sortdir; + if (curr_mode == "path") { + event.currentTarget.dataset.sortmode = "name"; + event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by filename"); + } else if (curr_mode == "name") { + event.currentTarget.dataset.sortmode = "date_created"; + event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by date created"); + } else if (curr_mode == "date_created") { + event.currentTarget.dataset.sortmode = "date_modified"; + event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by date modified"); + } else { + event.currentTarget.dataset.sortmode = "path"; + event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by path"); + } + applyExtraNetworkSort(tabname + "_" + tab_id); +} + +function extraNetworksTreeSortDirOnClick(event, tabname, tab_id) { + var curr_dir = event.currentTarget.getAttribute("data-sortdir"); + if (curr_dir == "Ascending") { + event.currentTarget.dataset.sortdir = "Descending"; + event.currentTarget.setAttribute("title", "Sort descending"); + } else { + event.currentTarget.dataset.sortdir = "Ascending"; + event.currentTarget.setAttribute("title", "Sort ascending"); + } + applyExtraNetworkSort(tabname + "_" + tab_id); +} + +function extraNetworksTreeRefreshOnClick(event, tabname, tab_id) { + console.log("refresh clicked"); + var btn_refresh_internal = gradioApp().getElementById(tabname + "_extra_refresh_internal"); + btn_refresh_internal.dispatchEvent(new Event("click")); +} + var globalPopup = null; var globalPopupInner = null; diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a49c6c1c54b..4ba2bea1a8c 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -19,50 +19,6 @@ allowed_dirs = set() default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] -tree_tpl = ( - "" - "
          " - "{content}" - "
        " -) - -tree_ul_tpl = ( - "
          " - "{content}" - "
        " -) - -tree_li_dir_tpl = ( - "
      • " - "{content}" - "
      • " -) -tree_li_file_tpl = ( - "
      • " - "{content}" - "
      • " -) - -action_list_item_action_leading = ( - "" - "" - "" -) - @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) @@ -201,9 +157,13 @@ def __init__(self, title): self.title = title self.name = title.lower() self.id_page = self.name.replace(" ", "_") - self.extra_networks_pane_template = shared.html("extra-networks-pane.html") - self.card_page_template = shared.html("extra-networks-card.html") - self.tree_button_template = shared.html("extra-networks-tree-button.html") + self.pane_tpl = shared.html("extra-networks-pane.html") + self.tree_tpl = shared.html("extra-networks-tree.html") + self.card_tpl = shared.html("extra-networks-card.html") + self.btn_tree_tpl = shared.html("extra-networks-tree-button.html") + self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") + self.btn_metadata_tpl = shared.html("extra-networks-metadata-button.html") + self.btn_edit_item_tpl = shared.html("extra-networks-edit-item-button.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -268,12 +228,8 @@ def create_item_html( onclick = item.get("onclick", None) if onclick is None: - print("HERE") - print("TABNAME:", tabname) - print("PROMPT:", item["prompt"]) - print("NEG_PROMPT:", item.get("negative_prompt", "")) - print("ALLOW_NEG:", self.allow_negative_prompt) - onclick_js_tpl = "cardClicked('{tabname}', '{prompt}', '{neg_prompt}', '{allow_neg}');" + # Don't quote prompt/neg_prompt since they are stored as js strings already. + onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, '{allow_neg}');" onclick = onclick_js_tpl.format( **{ "tabname": tabname, @@ -284,15 +240,23 @@ def create_item_html( ) onclick = html.escape(onclick) - - copy_path_button = f"
        " - - metadata_button = "" + btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]}) + btn_metadata = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" - - edit_button = f"
        " + btn_metadata = self.btn_metadata_tpl.format( + **{ + "page_id": self.id_page, + "name": html.escape(item["name"]), + } + ) + btn_edit_item = self.btn_edit_item_tpl.format( + **{ + "tabname": tabname, + "page_id": self.id_page, + "name": html.escape(item["name"]), + } + ) local_path = "" filename = item.get("filename", "") @@ -334,11 +298,11 @@ def create_item_html( args = { "background_image": background_image, "card_clicked": onclick, - "copy_path_button": copy_path_button, + "copy_path_button": btn_copy_path, "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), - "edit_button": edit_button, + "edit_button": btn_edit_item, "local_preview": quote_js(item["local_preview"]), - "metadata_button": metadata_button, + "metadata_button": btn_metadata, "name": html.escape(item["name"]), "prompt": item.get("prompt", None), "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', @@ -355,6 +319,57 @@ def create_item_html( else: return args + def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Optional[str] = None) -> Optional[str]: + if not content: + return None + + btn = self.btn_tree_tpl.format( + **{ + "search_terms": "", + "subclass": "tree-list-content-dir", + "tabname": tabname, + "tab_id": self.id_page, + "onclick_extra": "", + "data_path": dir_path, + "data_hash": "", + "action_list_item_action_leading": "", + "action_list_item_visual_leading": "🗀", + "action_list_item_label": os.path.basename(dir_path), + "action_list_item_visual_trailing": "", + "action_list_item_action_trailing": "", + } + ) + ul = f"
          {content}
        " + return f"
      • {btn + ul}
      • " + + def create_tree_file_item_html(self, tabname: str, item_name: str, item: dict) -> str: + item_html_args = self.create_item_html(tabname, item) + action_buttons = "".join( + [ + item_html_args["copy_path_button"], + item_html_args["metadata_button"], + item_html_args["edit_button"], + ] + ) + action_buttons = f"
        {action_buttons}
        " + btn = self.btn_tree_tpl.format( + **{ + "search_terms": "", + "subclass": "tree-list-content-file", + "tabname": tabname, + "tab_id": self.id_page, + "onclick_extra": item_html_args["card_clicked"], + "data_path": item_name, + "data_hash": item["shorthash"], + "action_list_item_action_leading": "", + "action_list_item_visual_leading": "🗎", + "action_list_item_label": item["name"], + "action_list_item_visual_trailing": "", + "action_list_item_action_trailing": action_buttons, + } + ) + return f"
      • {btn}
      • " + def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. @@ -385,57 +400,9 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): if isinstance(v, (ExtraNetworksItem,)): - _item_html_args = self.create_item_html(tabname, v.item) - _action_buttons = "".join( - [ - _item_html_args["copy_path_button"], - _item_html_args["metadata_button"], - _item_html_args["edit_button"], - ] - ) - _action_buttons = f"
        {_action_buttons}
        " - _btn = self.tree_button_template.format( - **{ - "search_terms": "", - "subclass": "tree-list-content-file", - "tabname": tabname, - "tab_id": self.id_page, - "onclick_extra": _item_html_args["card_clicked"], - "action_list_item_action_leading": action_list_item_action_leading, - "action_list_item_visual_leading": "🗎", - "action_list_item_label": v.item["name"], - "action_list_item_visual_trailing": "", - "action_list_item_action_trailing": _action_buttons, - } - ) - - _li = tree_li_file_tpl.format( - **{ - "hash": v.item["shorthash"], - "path": k, - "type": "file", - "content": _btn, - } - ) - _file_li.append(_li) + _file_li.append(self.create_tree_file_item_html(tabname, k, v.item)) else: - _btn = self.tree_button_template.format( - **{ - "search_terms": "", - "subclass": "tree-list-content-dir", - "tabname": tabname, - "tab_id": self.id_page, - "onclick_extra": "", - "action_list_item_action_leading": action_list_item_action_leading, - "action_list_item_visual_leading": "🗀", - "action_list_item_label": os.path.basename(k), - "action_list_item_visual_trailing": "", - "action_list_item_action_trailing": "", - } - ) - _ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) - _li = tree_li_dir_tpl.format(**{"content": _btn + _ul, "path": k}) - _dir_li.append(_li) + _dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v))) # Directories should always be displayed before files. return "".join(_dir_li) + "".join(_file_li) @@ -443,31 +410,15 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): # If root is empty, append the "disabled" attribute to the template details tag. - btn = self.tree_button_template.format( - **{ - "search_terms": "", - "subclass": "tree-list-content-dir", - "tabname": tabname, - "tab_id": self.id_page, - "onclick_extra": "", - "action_list_item_action_leading": action_list_item_action_leading, - "action_list_item_visual_leading": "🗀", - "action_list_item_label": os.path.basename(k), - "action_list_item_visual_trailing": "", - "action_list_item_action_trailing": "", - } - ) - subtree = _build_tree(v) - if subtree: - ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) - li = tree_li_dir_tpl.format(**{"content": btn + ul, "path": k}) - res += li + item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v)) + if item_html: + res += item_html - return tree_tpl.format( + return self.tree_tpl.format( **{ - "content": res, "tabname": tabname, "tab_id": self.id_page, + "tree": f"
          {res}
        " } ) @@ -475,8 +426,7 @@ def create_card_view_html(self, tabname): res = "" self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): - print("HEEEERRE:", item) - res += self.create_item_html(tabname, item, self.card_page_template) + res += self.create_item_html(tabname, item, self.card_tpl) if res == "": dirs = "".join([f"
      • {x}
      • " for x in self.allowed_directories_for_previews()]) @@ -493,7 +443,7 @@ def create_html(self, tabname): card_view_html = self.create_card_view_html(tabname) network_type_id = self.id_page - return self.extra_networks_pane_template.format( + return self.pane_tpl.format( **{ "tabname": tabname, "network_type_id": network_type_id, @@ -612,6 +562,8 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs = [] + button_refresh = gr.Button("Refresh", elem_id=tabname+"_extra_refresh_internal", visible=False) + for page in ui.stored_extra_pages: with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]): @@ -633,51 +585,9 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") - button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") - button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) - - tab_controls = [ - dropdown_sort, - button_sortorder, - button_refresh, - ] - ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - for tab in unrelated_tabs: - tab.select( - fn=lambda: [gr.update(visible=False) for _ in tab_controls], - _js=f"function(){{ extraNetworksUnrelatedTabSelected('{tabname}'); }}", - inputs=[], - outputs=tab_controls, - show_progress=False, - ) - - for page, tab in zip(ui.stored_extra_pages, related_tabs): - allow_prompt = "true" if page.allow_prompt else "false" - allow_negative_prompt = "true" if page.allow_negative_prompt else "false" - - jscode = ( - "extraNetworksTabSelected(" - f"'{tabname}', " - f"'{tabname}_{page.id_page}_prompts', " - f"'{allow_prompt}', " - f"'{allow_negative_prompt}'" - ");" - ) - - tab.select( - fn=lambda: [gr.update(visible=True) for _ in tab_controls], - _js="function(){ " + jscode + " }", - inputs=[], - outputs=tab_controls, - show_progress=False, - ) - - dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") - def create_html(): ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] @@ -693,6 +603,8 @@ def refresh(): return ui.pages_contents interface.load(fn=pages_html, inputs=[], outputs=ui.pages) + # NOTE: Event is manually fired in extraNetworks.js:extraNetworksTreeRefreshOnClick() + # button is unused and hidden at all times. Only used in order to fire this event. 
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui diff --git a/style.css b/style.css index 2dafe97fd87..08573248654 100644 --- a/style.css +++ b/style.css @@ -1196,17 +1196,33 @@ body.resizing .resize-handle { overflow: hidden; } -.extra-network-tree .tree-list--tree { - cursor: pointer; - -webkit-user-select: none; - -moz-user-select: none; - -ms-user-select: none; - user-select: none; - margin: 0; +.extra-network-tree .tree-list { + margin: 0 0.25rem; padding: 0; - margin-left: 0.25rem; } +.extra-network-tree .tree-list .tree-list-controls { + position: relative; + display: grid; + width: 100%; + padding: 0 !important; + margin-top: 0 !important; + margin-bottom: 0 !important; + font-size: 1rem; + text-align: left; + user-select: none; + background-color: transparent; + border: none; + transition: background 33.333ms linear; + grid-template-rows: min-content; + grid-template-areas: "tree-list-controls-col-0 tree-list-controls-col-1 tree-list-controls-col-2 tree-list-controls-col-3"; + grid-template-columns: minmax(0, auto) min-content min-content min-content; + grid-gap: 0.1rem; + align-items: start; +} + +.extra-network-tree .tree-list--tree {} + /* Remove auto indentation from tree. Will be overridden later. */ .extra-network-tree .tree-list--subgroup { margin: 0 !important; @@ -1221,9 +1237,6 @@ body.resizing .resize-handle { padding-left: 0.4rem !important; } -/* Styles for tree
<ul> elements. */
-.extra-network-tree .tree-list {}
-
 /* Styles for tree
        • elements. */ .extra-network-tree .tree-list-item { list-style: none; @@ -1288,26 +1301,182 @@ body.resizing .resize-handle { padding-top: 0.5rem !important; } -.dark .extra-network-tree button.tree-list-content:hover { +.dark .extra-network-tree div.tree-list-content:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-800); } -.dark .extra-network-tree button.tree-list-content[selected] { +.dark .extra-network-tree div.tree-list-content[selected] { background-color: var(--neutral-700); } -.extra-network-tree button.tree-list-content:hover { +.extra-network-tree div.tree-list-content:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-200); } -.extra-network-tree button.tree-list-content[selected] { +.extra-network-tree div.tree-list-content[selected] { background-color: var(--neutral-300); } +/* ==== CHEVRON ICON ACTIONS ==== */ +/* Define the animation for the arrow when it is clicked. */ +.extra-network-tree .tree-list-content-dir[expanded=false] .tree-list-item-action-chevron { + -ms-transform: rotate(135deg); + -webkit-transform: rotate(135deg); + transform: rotate(135deg); + transition: transform 0.2s; +} + +.extra-network-tree .tree-list-content-dir[expanded=true] .tree-list-item-action-chevron { + -ms-transform: rotate(225deg); + -webkit-transform: rotate(225deg); + transform: rotate(225deg); + transition: transform 0.2s; +} + +.tree-list-item-action-chevron { + display: inline-flex; + /* Uses box shadow to generate a pseudo chevron `>` icon. */ + padding: 0.3rem; + box-shadow: 0.1rem 0.1rem 0 0 var(--neutral-200) inset; + transform: rotate(135deg); +} + +/* ==== SEARCH INPUT ACTIONS ==== */ +/* Add icon to left side of */ +.extra-network-tree .tree-list-controls .tree-list-search::before { + content: "🔎︎"; + position: absolute; + margin: 0.5rem; + font-size: 1rem; + color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-controls .tree-list-search { + display: inline-flex; + grid-area: tree-list-controls-col-0; + position: relative; + margin: 0.5rem; +} + +.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text { + border: 1px solid var(--button-secondary-border-color); + border-radius: 0.5rem; + color: var(--button-secondary-text-color); + background-color: transparent; + width: 100%; + padding-left: 2rem; + line-height: 1rem; +} + +/* clear button (x on right side) styling */ +.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text::-webkit-search-cancel-button { + -webkit-appearance: none; + appearance: none; + cursor: pointer; + height: 1rem; + width: 1rem; + mask-image: url('data:image/svg+xml,'); + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +/* ==== SORT ICON ACTIONS ==== */ +.extra-network-tree .tree-list-controls .tree-list-sort { + grid-area: tree-list-controls-col-1; + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-tree .tree-list-controls .tree-list-sort .tree-list-sort-icon { + height: 1.5rem; + width: 1.5rem; + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-sort[data-sortmode="path"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree 
.tree-list-sort[data-sortmode="name"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort[data-sortmode="date_created"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort[data-sortmode="date_modified"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +/* ==== SORT DIRECTION ICON ACTIONS ==== */ +.extra-network-tree .tree-list-controls .tree-list-sort-dir { + grid-area: tree-list-controls-col-2; + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-tree .tree-list-controls .tree-list-sort-dir .tree-list-sort-dir-icon { + height: 1.5rem; + width: 1.5rem; + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-sort-dir[data-sortdir="Ascending"] .tree-list-sort-dir-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort-dir[data-sortdir="Descending"] .tree-list-sort-dir-icon { + mask-image: url('data:image/svg+xml,'); +} + +/* ==== REFRESH ICON ACTIONS ==== */ +.extra-network-tree .tree-list-controls .tree-list-refresh { + grid-area: tree-list-controls-col-3; + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-tree .tree-list-controls .tree-list-refresh .tree-list-refresh-icon { + height: 1.5rem; + width: 1.5rem; + mask-image: url('data:image/svg+xml,'); + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-refresh-icon:active { + -ms-transform: rotate(180deg); + -webkit-transform: rotate(180deg); + transform: rotate(180deg); + transition: transform 0.2s; +} + +/* ==== TREE GRID CONFIG ==== */ + /* Text for button. */ .extra-network-tree .tree-list-item-label { position: relative; @@ -1332,6 +1501,7 @@ body.resizing .resize-handle { align-items: right; } + /* Icon for button when it is before label. */ .extra-network-tree .tree-list-item-visual--leading { grid-area: leading-visual; @@ -1348,7 +1518,7 @@ body.resizing .resize-handle { /* Dropdown arrow for button. */ .extra-network-tree .tree-list-item-action--leading { - margin-right: 0.2rem; + margin-right: 0.5rem; margin-left: 0.2rem; } @@ -1356,30 +1526,6 @@ body.resizing .resize-handle { visibility: hidden; } -/* Define the animation for the arrow when it is clicked. */ -.extra-network-tree .tree-list-content-dir[expanded=false] .tree-list-item-action-chevron { - -ms-transform: rotate(135deg); - -webkit-transform: rotate(135deg); - transform: rotate(135deg); - transition: transform 0.2s; -} - -.extra-network-tree .tree-list-content-dir[expanded=true] .tree-list-item-action-chevron { - -ms-transform: rotate(225deg); - -webkit-transform: rotate(225deg); - transform: rotate(225deg); - transition: transform 0.2s; -} - -.tree-list-item-action-chevron { - display: inline-flex; - /* Uses box shadow to generate a pseudo chevron `>` icon. 
*/ - padding: 0.3rem; - box-shadow: 0.1rem 0.1rem 0 0 var(--neutral-200) inset; - transform: rotate(135deg); -} - - .extra-network-tree .tree-list-item-action--leading { grid-area: leading-action; } @@ -1399,41 +1545,3 @@ body.resizing .resize-handle { .extra-network-tree .tree-list-content:hover .button-row { visibility: visible; } - -/* Add icon to left side of */ -.extra-network-tree .tree-list-search::before { - content: "🔎︎"; - position: absolute; - margin: 0.5rem; - font-size: 1rem; - color: var(--input-placeholder-color); -} - -.extra-network-tree .tree-list-search { - position: relative; - margin: 0.5rem; -} - -.extra-network-tree .tree-list-search .tree-list-search-text { - border: 1px solid var(--button-secondary-border-color); - border-radius: 0.5rem; - color: var(--button-secondary-text-color); - background-color: transparent; - width: 100%; - padding-left: 2rem; - line-height: 1rem; -} - -/* clear button (x on right side) styling */ -.extra-network-tree .tree-list-search .tree-list-search-text::-webkit-search-cancel-button { - -webkit-appearance: none; - appearance: none; - cursor: pointer; - height: 1rem; - width: 1rem; - mask-image: url('data:image/svg+xml,'); - mask-repeat: no-repeat; - mask-position: center center; - mask-size: 100%; - background-color: var(--input-placeholder-color); -} \ No newline at end of file From 1fdc18e6a01eb40889d46fab40f21aa138d64b01 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 15 Jan 2024 18:01:13 -0500 Subject: [PATCH 0592/1174] Run linting --- modules/ui_extra_networks.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 4ba2bea1a8c..06fa22afada 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -13,7 +13,6 @@ from fastapi.exceptions import HTTPException from modules.infotext_utils import image_from_url_text -from modules.ui_components import ToolButton extra_pages = [] allowed_dirs = set() @@ -225,7 +224,6 @@ def create_item_html( width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' background_image = f'' if preview else '' - onclick = item.get("onclick", None) if onclick is None: # Don't quote prompt/neg_prompt since they are stored as js strings already. 
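The comment above is the crux of this hunk: values passed through `quote_js` are already JavaScript string literals, so the template must interpolate them bare. A minimal standalone sketch of the failure mode, using `json.dumps` as a stand-in for the webui's `quote_js` helper (an assumption about its exact escaping):

```python
import json


def quote_js(s: str) -> str:
    # Stand-in: json.dumps yields a complete double-quoted JS string literal.
    return json.dumps(s)


prompt = quote_js("a photo of a cat")

# New template: value interpolated bare -> valid JS call.
print("cardClicked('txt2img', {prompt}, 'true');".format(prompt=prompt))
# cardClicked('txt2img', "a photo of a cat", 'true');

# Old template: value wrapped in quotes again -> the handler receives the
# literal text '"a photo of a cat"', quotes included.
print("cardClicked('txt2img', '{prompt}', 'true');".format(prompt=prompt))
# cardClicked('txt2img', '"a photo of a cat"', 'true');
```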
@@ -239,7 +237,7 @@ def create_item_html( } ) onclick = html.escape(onclick) - + btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]}) btn_metadata = "" metadata = item.get("metadata") @@ -551,8 +549,6 @@ def tab_name_score(name): return sorted(pages, key=lambda x: tab_scores[x.name]) def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): - from modules.ui import switch_values_symbol - ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] @@ -588,6 +584,9 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + for tab in unrelated_tabs: + tab.select(fn=None, _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=[], show_progress=False) + def create_html(): ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] From c1e04c63b3d2a5102b4cf6deddea337c8a964c53 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:18:20 +0900 Subject: [PATCH 0593/1174] callback postprocess_image_after_composite --- modules/processing.py | 5 +++++ modules/scripts.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index dcc807fe38d..2c3365a026f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1029,6 +1029,11 @@ def infotext(index=0, use_main_prompt=False): image = apply_overlay(image, p.paste_to, overlay_image) + if p.scripts is not None: + pp = scripts.PostprocessImageArgs(image) + p.scripts.postprocess_image_after_composite(p, pp) + image = pp.image + if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) diff --git a/modules/scripts.py b/modules/scripts.py index cf938ebb972..060069cf36a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -262,6 +262,15 @@ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args): pass + def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args): + """ + Called for every image after it has been generated. + Same as postprocess_image but after inpaint_full_res composite + So that it operates on the full image instead of the inpaint_full_res crop region. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. 
@@ -856,6 +865,14 @@ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): except Exception: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_image_after_composite(p, pp, *script_args) + except Exception: + errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True) + def before_component(self, component, **kwargs): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): try: From 14b9762bcab41fe0f7411498d9d3ffdc69ea1dda Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 17:08:07 +0900 Subject: [PATCH 0594/1174] immediately stop on second interrupt Revert "immediately stop on second interrupt" This reverts commit ab409072a1f2a9c911a63aee98a6b42081803cdc. immediately stop on second interrupt --- modules/ui_toprow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 1abc9117ba3..fbe705be130 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -107,8 +107,9 @@ def create_submit_box(self): ) def interrupt_function(): - if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + if not shared.state.stopping_generation and shared.state.job_count > 1 and shared.opts.interrupt_after_current: shared.state.stop_generating() + gr.Info("Generation will stop after finishing this image, click again to stop immediately.") else: shared.state.interrupt() From a75dfe1c0de1564f4607a6f4ed7f4a7c00ef18a0 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Tue, 16 Jan 2024 19:03:48 +0100 Subject: [PATCH 0595/1174] - expand fields to include model name and hash - write these in the CSV log file - ensure old log files are updated w.r.t delimiter count --- modules/ui_common.py | 47 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index f17259c2956..cbb2495d6c9 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -36,6 +36,29 @@ def plaintext_to_html(text, classname=None): return f"

<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>
          " +def update_logfile(logfile_path, fields): + import csv + + with open(logfile_path, "r", encoding="utf8", newline="") as file: + reader = csv.reader(file) + rows = list(reader) + + # blank file: leave it as is + if not rows: + return + + rows[0] = fields + + # append new fields to each row as empty values + for row in rows[1:]: + while len(row) < len(fields): + row.append("") + + with open(logfile_path, "w", encoding="utf8", newline="") as file: + writer = csv.writer(file) + writer.writerows(rows) + + def save_files(js_data, images, do_make_zip, index): import csv filenames = [] @@ -64,11 +87,31 @@ def __init__(self, d=None): os.makedirs(shared.opts.outdir_save, exist_ok=True) + fields = [ + "prompt", + "seed", + "width", + "height", + "sampler", + "cfgs", + "steps", + "filename", + "negative_prompt", + "sd_model_name", + "sd_model_hash", + ] + logfile_path = os.path.join(shared.opts.outdir_save, "log.csv") + + # NOTE: ensure csv integrity when fields are added by + # updating headers and padding with delimeters where needed + if os.path.exists(logfile_path): + update_logfile(logfile_path, fields) + with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: at_start = file.tell() == 0 writer = csv.writer(file) if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + writer.writerow(fields) for image_index, filedata in enumerate(images, start_index): image = image_from_url_text(filedata) @@ -86,7 +129,7 @@ def __init__(self, d=None): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"], data["sd_model_name"], data["sd_model_hash"]]) # Make Zip if do_make_zip: From 315e40a49c32438551ed6b66138acdf664ecdbc8 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Tue, 16 Jan 2024 19:11:28 +0100 Subject: [PATCH 0596/1174] reuse variable for log file path --- modules/ui_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index cbb2495d6c9..1db72092d59 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -107,7 +107,7 @@ def __init__(self, d=None): if os.path.exists(logfile_path): update_logfile(logfile_path, fields) - with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + with open(logfile_path, "a", encoding="utf8", newline='') as file: at_start = file.tell() == 0 writer = csv.writer(file) if at_start: From 4f9626703345ced77935e6bbb06de0b4522d53b7 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 16 Jan 2024 13:35:01 -0500 Subject: [PATCH 0597/1174] Finish cleanup. 
--- html/extra-networks-card-minimal.html | 4 - html/extra-networks-card.html | 10 +- html/extra-networks-edit-item-button.html | 2 +- html/extra-networks-metadata-button.html | 2 +- html/extra-networks-pane.html | 6 +- html/extra-networks-tree-button.html | 3 +- html/extra-networks-tree.html | 14 +- javascript/extraNetworks.js | 170 +++++++++++---------- modules/ui_extra_networks.py | 162 +++++++++++++++----- modules/ui_extra_networks_user_metadata.py | 2 +- style.css | 81 +++++++--- 11 files changed, 288 insertions(+), 168 deletions(-) delete mode 100644 html/extra-networks-card-minimal.html diff --git a/html/extra-networks-card-minimal.html b/html/extra-networks-card-minimal.html deleted file mode 100644 index d66df7dfb3f..00000000000 --- a/html/extra-networks-card-minimal.html +++ /dev/null @@ -1,4 +0,0 @@ -
[markup stripped in extraction -- the deleted minimal-card template wrapped {name} and {copy_path_button}{metadata_button}{edit_button}]
          diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index ca683dc47ec..f1d959a6733 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,9 +1,9 @@ -
[markup stripped in extraction -- old and new card templates both render {background_image}, {copy_path_button}{metadata_button}{edit_button}, {search_terms}, {name} and {description}; the patch changes only the surrounding wrapper markup]
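Since the card markup itself did not survive extraction, a hypothetical two-placeholder template illustrates the mechanism: `create_item_html` builds an `args` dict and fills whichever template it is given via `str.format`:

```python
# Hypothetical stand-in; the real template is html/extra-networks-card.html.
card_tpl = "{name} -- {description}"
args = {"name": "my_lora", "description": "example card"}
print(card_tpl.format(**args))  # my_lora -- example card
```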
          diff --git a/html/extra-networks-edit-item-button.html b/html/extra-networks-edit-item-button.html index 7d2677d9f14..0fe43082ad1 100644 --- a/html/extra-networks-edit-item-button.html +++ b/html/extra-networks-edit-item-button.html @@ -1,4 +1,4 @@
          + onclick="extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}', '{name}')">
          \ No newline at end of file diff --git a/html/extra-networks-metadata-button.html b/html/extra-networks-metadata-button.html index ad6d6f41662..285b5b3b658 100644 --- a/html/extra-networks-metadata-button.html +++ b/html/extra-networks-metadata-button.html @@ -1,4 +1,4 @@ \ No newline at end of file diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 20cf6686755..bf46ca16305 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,8 +1,8 @@ -
[markup stripped in extraction -- the pane template keeps rendering {tree_html} and {items_html}; only their wrapper elements change]
          \ No newline at end of file diff --git a/html/extra-networks-tree-button.html b/html/extra-networks-tree-button.html index 20a9b0b863f..9dc2e2a40c8 100644 --- a/html/extra-networks-tree-button.html +++ b/html/extra-networks-tree-button.html @@ -1,8 +1,7 @@
          diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html index 4d29b1be273..23f6af1058c 100644 --- a/html/extra-networks-tree.html +++ b/html/extra-networks-tree.html @@ -2,36 +2,36 @@
          diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index cf98452a296..a3f003bf475 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -31,25 +31,15 @@ function setupExtraNetworksForTab(tabname) { var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs'); this_tab.classList.add('extra-networks'); this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) { - var tab_id = elem.getAttribute("id"); - var search = gradioApp().querySelector("#" + tab_id + "_extra_search"); - if (!search) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. - } - - var sort = gradioApp().querySelector("#" + tab_id + "_extra_sort"); - if (!sort) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. - } - - var sort_dir = gradioApp().querySelector("#" + tab_id + "_extra_sort_dir"); - if (!sort_dir) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. - } - - var refresh = gradioApp().querySelector("#" + tab_id + "_extra_refresh"); - if (!refresh) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. + var extra_networks_tabname = elem.id; + var search = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_search"); + var sort_mode = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort"); + var sort_dir = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort_dir"); + var refresh = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_refresh"); + + // If any of the buttons above don't exist, we want to skip this iteration of the loop. + if (!search || !sort_mode || !sort_dir || !refresh) { + return; // `return` is equivalent of `continue` but for forEach loops. } var applyFilter = function() { @@ -78,14 +68,14 @@ function setupExtraNetworksForTab(tabname) { var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); var reverse = sort_dir.dataset.sortdir == "Descending"; - var sortKey = sort.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; - if (sortKeyStore == sort.dataset.sortkey) { + if (sortKeyStore == sort_mode.dataset.sortkey) { return; } - sort.dataset.sortkey = sortKeyStore; + sort_mode.dataset.sortkey = sortKeyStore; cards.forEach(function(card) { card.originalParentElement = card.parentElement; @@ -115,8 +105,8 @@ function setupExtraNetworksForTab(tabname) { applySort(); applyFilter(); - extraNetworksApplySort[tab_id] = applySort; - extraNetworksApplyFilter[tab_id] = applyFilter; + extraNetworksApplySort[extra_networks_tabname] = applySort; + extraNetworksApplyFilter[extra_networks_tabname] = applyFilter; }); registerPrompt(tabname, tabname + "_prompt"); @@ -148,14 +138,6 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp } } -function clearSearch(tabname) { - // Clear search box. 
- var tab_id = tabname + "_extra_search"; - var searchTextarea = gradioApp().querySelector("#" + tab_id + ' > label > textarea'); - searchTextarea.value = ""; - updateInput(searchTextarea); -} - function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) extraNetworksMovePromptToTab(tabname, '', false, false); @@ -264,22 +246,20 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } -function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { +function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) { /** * Processes `onclick` events when user clicks on files in tree. * - * @param event The generated event. - * @param btn The clicked `tree-list-item` button. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + * @param event The generated event. + * @param btn The clicked `tree-list-item` button. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ - var par = btn.parentElement; - var search_id = tabname + "_" + tab_id + "_extra_search"; - var type = par.getAttribute("data-tree-entry-type"); - var path = btn.getAttribute("data-path"); + // NOTE: Currently unused. + return; } -function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { +function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname) { /** * Processes `onclick` events when user clicks on directories in tree. * @@ -289,10 +269,10 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { * selected opened directory: Directory is collapsed and deselected. * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged. * - * @param event The generated event. - * @param btn The clicked `tree-list-item` button. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + * @param event The generated event. + * @param btn The clicked `tree-list-item` button. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ var ul = btn.nextElementSibling; // This is the actual target that the user clicked on within the target button. @@ -301,12 +281,12 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function _expand_or_collapse(_ul, _btn) { // Expands
            if it is collapsed, collapses otherwise. Updates button attributes. - if (_ul.hasAttribute("data-hidden")) { - _ul.removeAttribute("data-hidden"); - _btn.setAttribute("expanded", "true"); + if (_ul.hasAttribute("hidden")) { + _ul.removeAttribute("hidden"); + _btn.dataset.expanded = ""; } else { - _ul.setAttribute("data-hidden", ""); - _btn.setAttribute("expanded", "false"); + _ul.setAttribute("hidden", ""); + delete _btn.dataset.expanded; } } @@ -314,19 +294,19 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { // Removes the `selected` attribute from all buttons. var sels = document.querySelectorAll("div.tree-list-content"); [...sels].forEach(el => { - el.removeAttribute("selected"); + delete el.dataset.selected; }); } function _select_button(_btn) { - // Removes `selected` attribute from all buttons then adds to passed button. + // Removes `data-selected` attribute from all buttons then adds to passed button. _remove_selected_from_all(); - _btn.setAttribute("selected", ""); + _btn.dataset.selected = ""; } - function _update_search(_tabname, _tab_id, _search_text) { + function _update_search(_tabname, _extra_networks_tabname, _search_text) { // Update search input with select button's path. - var search_input_elem = gradioApp().querySelector("#" + tabname + "_" + tab_id + "_extra_search"); + var search_input_elem = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_search"); search_input_elem.value = _search_text; updateInput(search_input_elem); } @@ -337,48 +317,58 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { _expand_or_collapse(ul, btn); } else { // User clicked anywhere else on the button. - if (btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden")) { + if ("selected" in btn.dataset && !(ul.hasAttribute("hidden"))) { // If folder is select and open, collapse and deselect button. _expand_or_collapse(ul, btn); - btn.removeAttribute("selected"); - _update_search(tabname, tab_id, ""); - } else if (!(!btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden"))) { + delete btn.dataset.selected; + _update_search(tabname, extra_networks_tabname, ""); + } else if (!(!("selected" in btn.dataset) && !(ul.hasAttribute("hidden")))) { // If folder is open and not selected, then we don't collapse; just select. // NOTE: Double inversion sucks but it is the clearest way to show the branching here. _expand_or_collapse(ul, btn); - _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.getAttribute("data-path")); + _select_button(btn, tabname, extra_networks_tabname); + _update_search(tabname, extra_networks_tabname, btn.dataset.path); } else { // All other cases, just select the button. - _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.getAttribute("data-path")); + _select_button(btn, tabname, extra_networks_tabname); + _update_search(tabname, extra_networks_tabname, btn.dataset.path); } } } -function extraNetworksTreeOnClick(event, tabname, tab_id) { +function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for buttons within an `extra-network-tree .tree-list--tree`. * * Determines whether the clicked button in the tree is for a file entry or a directory * then calls the appropriate function. * - * @param event The generated event. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param tab_id The id of the active extraNetworks tab. 
Ex: lora, checkpoints, etc. + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ var btn = event.currentTarget; var par = btn.parentElement; - if (par.getAttribute("data-tree-entry-type") === "file") { - extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id); + if (par.dataset.treeEntryType === "file") { + extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname); } else { - extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id); + extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname); } } -function extraNetworksTreeSortOnClick(event, tabname, tab_id) { +function extraNetworksTreeSortOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Sort Mode button. + * + * Modifies the data attributes of the Sort Mode button to cycle between + * various sorting modes. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ var curr_mode = event.currentTarget.dataset.sortmode; - var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + tab_id + "_extra_sort_dir"); + var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir"); var sort_dir = el_sort_dir.dataset.sortdir; if (curr_mode == "path") { event.currentTarget.dataset.sortmode = "name"; @@ -397,23 +387,43 @@ function extraNetworksTreeSortOnClick(event, tabname, tab_id) { event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640"; event.currentTarget.setAttribute("title", "Sort by path"); } - applyExtraNetworkSort(tabname + "_" + tab_id); + applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeSortDirOnClick(event, tabname, tab_id) { - var curr_dir = event.currentTarget.getAttribute("data-sortdir"); - if (curr_dir == "Ascending") { +function extraNetworksTreeSortDirOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Sort Direction button. + * + * Modifies the data attributes of the Sort Direction button to cycle between + * ascending and descending sort directions. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + if (event.currentTarget.dataset.sortdir == "Ascending") { event.currentTarget.dataset.sortdir = "Descending"; event.currentTarget.setAttribute("title", "Sort descending"); } else { event.currentTarget.dataset.sortdir = "Ascending"; event.currentTarget.setAttribute("title", "Sort ascending"); } - applyExtraNetworkSort(tabname + "_" + tab_id); + applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeRefreshOnClick(event, tabname, tab_id) { - console.log("refresh clicked"); +function extraNetworksTreeRefreshOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Refresh Page button. 
+ * + * In order to actually call the python functions in `ui_extra_networks.py` + * to refresh the page, we created an empty gradio button in that file with an + * event handler that refreshes the page. So what this function here does + * is it manually raises a `click` event on that button. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ var btn_refresh_internal = gradioApp().getElementById(tabname + "_extra_refresh_internal"); btn_refresh_internal.dispatchEvent(new Event("click")); } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 06fa22afada..a03207b2bc0 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -155,7 +155,14 @@ class ExtraNetworksPage: def __init__(self, title): self.title = title self.name = title.lower() - self.id_page = self.name.replace(" ", "_") + # This is the actual name of the extra networks tab (not txt2img/img2img). + self.extra_networks_tabname = self.name.replace(" ", "_") + self.allow_prompt = True + self.allow_negative_prompt = False + self.metadata = {} + self.items = {} + self.lister = util.MassFileLister() + # HTML Templates self.pane_tpl = shared.html("extra-networks-pane.html") self.tree_tpl = shared.html("extra-networks-tree.html") self.card_tpl = shared.html("extra-networks-card.html") @@ -163,11 +170,6 @@ def __init__(self, title): self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") self.btn_metadata_tpl = shared.html("extra-networks-metadata-button.html") self.btn_edit_item_tpl = shared.html("extra-networks-edit-item-button.html") - self.allow_prompt = True - self.allow_negative_prompt = False - self.metadata = {} - self.items = {} - self.lister = util.MassFileLister() def refresh(self): pass @@ -202,15 +204,17 @@ def create_item_html( item: dict, template: Optional[str] = None, ) -> Union[str, dict]: - """Generates HTML for a single ExtraNetworks Item + """Generates HTML for a single ExtraNetworks Item. Args: tabname: The name of the active tab. item: Dictionary containing item information. + template: Optional template string to use. Returns: - HTML string generated for this item. - Can be empty if the item is not meant to be shown. + If a template is passed: HTML string generated for this item. + Can be empty if the item is not meant to be shown. + If no template is passed: A dictionary containing the generated item's attributes. 
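+
+        Example (an illustrative sketch; the returned HTML and the key list are
+        abbreviated, and `page`/`item` stand for any ExtraNetworksPage instance
+        and item dict):
+            >>> page.create_item_html("txt2img", item, template=page.card_tpl)
+            '<div class="card" ...>...</div>'
+            >>> sorted(page.create_item_html("txt2img", item))[:3]
+            ['background_image', 'card_clicked', 'copy_path_button']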
""" metadata = item.get("metadata") if metadata: @@ -244,14 +248,14 @@ def create_item_html( if metadata: btn_metadata = self.btn_metadata_tpl.format( **{ - "page_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "name": html.escape(item["name"]), } ) btn_edit_item = self.btn_edit_item_tpl.format( **{ "tabname": tabname, - "page_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "name": html.escape(item["name"]), } ) @@ -307,9 +311,9 @@ def create_item_html( "search_only": " search_only" if search_only else "", "search_terms": search_terms_html, "sort_keys": sort_keys, - "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", + "style": f"display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%", "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, } if template: @@ -317,7 +321,32 @@ def create_item_html( else: return args - def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Optional[str] = None) -> Optional[str]: + def create_tree_dir_item_html( + self, + tabname: str, + dir_path: str, + content: Optional[str] = None, + ) -> Optional[str]: + """Generates HTML for a directory item in the tree. + + The generated HTML is of the format: + ```html +
<li class="tree-list-item tree-list-item--has-subitem" data-tree-entry-type="dir">
+            <div class="tree-list-content tree-list-content-dir"></div>
+            <ul class="tree-list tree-list--subgroup" hidden>
+                {content}
+            </ul>
+        </li>
+        ```
+
+        Args:
+            tabname: The name of the active tab.
+            dir_path: Path to the directory for this item.
+            content: Optional HTML string that will be wrapped by this `<ul>`
              . + + Returns: + HTML formatted string. + """ if not content: return None @@ -326,7 +355,7 @@ def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Option "search_terms": "", "subclass": "tree-list-content-dir", "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "onclick_extra": "", "data_path": dir_path, "data_hash": "", @@ -337,10 +366,32 @@ def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Option "action_list_item_action_trailing": "", } ) - ul = f"
<ul class='tree-list tree-list--subgroup' data-hidden>{content}</ul>"
-        return f"<li class='tree-list-item tree-list-item--has-subitem' data-tree-entry-type='dir'>{btn + ul}</li>"
+        ul = f"<ul class='tree-list tree-list--subgroup' hidden>{content}</ul>"
+        return (
+            "<li class='tree-list-item tree-list-item--has-subitem' data-tree-entry-type='dir'>"
+            f"{btn + ul}"
+            "</li>"
+        )
+
+    def create_tree_file_item_html(self, tabname: str, file_path: str, item: dict) -> str:
+        """Generates HTML for a file item in the tree.
+
+        The generated HTML is of the format:
+        ```html
+        <li class="tree-list-item tree-list-item--subitem" data-tree-entry-type="file">
+            <div class="tree-list-content tree-list-content-file"></div>
+        </li>
+        ```
-    def create_tree_file_item_html(self, tabname: str, item_name: str, item: dict) -> str:
+
+        Args:
+            tabname: The name of the active tab.
+            file_path: The path to the file for this item.
+            item: Dictionary containing the item information.
+
+        Returns:
+            HTML formatted string.
+        """
         item_html_args = self.create_item_html(tabname, item)
         action_buttons = "".join(
             [
@@ -355,9 +406,9 @@ def create_tree_file_item_html(self, tabname: str, item_name: str, item: dict) -> str:
                 "search_terms": "",
                 "subclass": "tree-list-content-file",
                 "tabname": tabname,
-                "tab_id": self.id_page,
+                "extra_networks_tabname": self.extra_networks_tabname,
                 "onclick_extra": item_html_args["card_clicked"],
-                "data_path": item_name,
+                "data_path": file_path,
                 "data_hash": item["shorthash"],
                 "action_list_item_action_leading": "",
                 "action_list_item_visual_leading": "🗎",
                 "action_list_item_action_trailing": action_buttons,
             }
         )
-        return f"<li class='tree-list-item tree-list-item--subitem' data-tree-entry-type='file'>{btn}</li>"
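+        # NOTE: the tag attributes above are reconstructed, but the routing they
+        # serve is visible in extraNetworks.js: extraNetworksTreeOnClick() reads
+        # the button parent's `dataset.treeEntryType` ("file" here, vs. the
+        # directory items) to pick the click handler.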
+        return (
+            "<li class='tree-list-item tree-list-item--subitem' data-tree-entry-type='file'>"
+            f"{btn}"
+            "</li>"
+        )

     def create_tree_view_html(self, tabname: str) -> str:
         """Generates HTML for displaying folders in a tree view.

+        The generated HTML uses `extra-networks-tree.html` as a template.
+
         Args:
             tabname: The name of the active tab.

@@ -379,7 +436,7 @@
         res = ""

-        # Generate HTML for the tree.
+        # Setup the tree dictionary.
         roots = self.allowed_directories_for_previews()
         tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()}
         tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items)

@@ -388,7 +445,17 @@
         def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional[str]:
-            """Recursively builds HTML for a tree."""
+            """Recursively builds HTML for a tree.
+
+            Args:
+                data: Dictionary representing a directory tree. Can be NoneType.
+                    Data keys should be absolute paths from the root and values
+                    should be subdirectory trees or an ExtraNetworksItem.
+
+            Returns:
+                If data is not None: HTML string
+                Else: None
+            """
             if not data:
                 return None

@@ -402,25 +469,36 @@
             else:
                 _dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v)))

-            # Directories should always be displayed before files.
+            # Directories should always be displayed before files so we order them here.
             return "".join(_dir_li) + "".join(_file_li)

         # Add each root directory to the tree.
         for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])):
-            # If root is empty, append the "disabled" attribute to the template details tag.
             item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v))
-            if item_html:
+            # Only add non-empty entries to the tree.
+            if item_html is not None:
                 res += item_html

         return self.tree_tpl.format(
             **{
                 "tabname": tabname,
-                "tab_id": self.id_page,
+                "extra_networks_tabname": self.extra_networks_tabname,
                 "tree": f"<ul class='tree-list tree-list--tree'>{res}</ul>"
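+                # (tree_tpl is the extra-networks-tree.html template named in the
+                # docstring above; this <ul> fills its {tree} placeholder.)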
            }
        )

-    def create_card_view_html(self, tabname):
+    def create_card_view_html(self, tabname: str) -> str:
+        """Generates HTML for the network Card View section for a tab.
+
+        This HTML goes into the `extra-networks-pane.html` <div>
              with + `id='{tabname}_{extra_networks_tabname}_cards`. + + Args: + tabname: The name of the active tab. + + Returns: + HTML formatted string. + """ res = "" self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): @@ -433,20 +511,26 @@ def create_card_view_html(self, tabname): return res def create_html(self, tabname): + """Generates an HTML string for the current pane. + + The generated HTML uses `extra-networks-pane.html` as a template. + + Args: + tabname: The name of the active tab. + + Returns: + HTML formatted string. + """ self.lister.reset() self.metadata = {} self.items = {x["name"]: x for x in self.list_items()} - tree_view_html = self.create_tree_view_html(tabname) - card_view_html = self.create_card_view_html(tabname) - network_type_id = self.id_page - return self.pane_tpl.format( **{ "tabname": tabname, - "network_type_id": network_type_id, - "tree_html": tree_view_html, - "items_html": card_view_html, + "extra_networks_tabname": self.extra_networks_tabname, + "tree_html": self.create_tree_view_html(tabname), + "items_html": self.create_card_view_html(tabname), } ) @@ -561,16 +645,16 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): button_refresh = gr.Button("Refresh", elem_id=tabname+"_extra_refresh_internal", visible=False) for page in ui.stored_extra_pages: - with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: - with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]): + with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab: + with gr.Column(elem_id=f"{tabname}_{page.extra_networks_tabname}_prompts", elem_classes=["extra-page-prompts"]): pass - elem_id = f"{tabname}_{page.id_page}_cards_html" + elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) page_elem.change( fn=lambda: None, - _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.id_page}_extra_search); return []}}", + _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.extra_networks_tabname}_extra_search); return []}}", inputs=[], outputs=[], ) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 989a649b799..2ca937fd117 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -14,7 +14,7 @@ def __init__(self, ui, tabname, page): self.ui = ui self.tabname = tabname self.page = page - self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata" + self.id_part = f"{self.tabname}_{self.page.extra_networks_tabname}_edit_user_metadata" self.box = None diff --git a/style.css b/style.css index 08573248654..1090e4368ab 100644 --- a/style.css +++ b/style.css @@ -879,13 +879,6 @@ footer { margin-bottom: 1em; } -.extra-network-pane{ - height: calc(100vh - 24rem); - overflow: clip scroll; - resize: vertical; - min-height: 52rem; -} - .extra-networks > div.tab-nav{ min-height: 3.4rem; } @@ -1182,23 +1175,63 @@ body.resizing .resize-handle { /* ========================= */ .extra-network-pane { display: flex; -} - -.extra-network-pane .extra-network-cards { - display: block; + height: calc(100vh - 24rem); + resize: vertical; + min-height: 52rem; } .extra-network-pane .extra-network-tree { - display: block; + flex: 1; + flex-direction: column; + display: flex; font-size: 1rem; - min-width: 25%; border: 1px solid 
var(--block-border-color); - overflow: hidden; } -.extra-network-tree .tree-list { - margin: 0 0.25rem; +.extra-network-pane .extra-network-cards { + flex: 3; + overflow: clip auto !important; + border: 1px solid var(--block-border-color); +} + +.extra-network-pane .extra-network-tree .tree-list { + flex: 1; + display: flex; + flex-direction: column; padding: 0; + width: 100%; + overflow: hidden; +} + +.extra-network-pane .extra-network-tree .tree-list .tree-list-container { + flex: 1; + overflow: clip auto !important; + width: 100%; +} + + +.extra-network-pane .extra-network-cards::-webkit-scrollbar, +.extra-network-pane .tree-list-container::-webkit-scrollbar { + background-color: transparent; + width: 16px; +} + +.extra-network-pane .extra-network-cards::-webkit-scrollbar-track, +.extra-network-pane .tree-list-container::-webkit-scrollbar-track { + background-color: transparent; + background-clip: content-box; +} + +.extra-network-pane .extra-network-cards::-webkit-scrollbar-thumb, +.extra-network-pane .tree-list-container::-webkit-scrollbar-thumb { + background-color: var(--border-color-primary); + border-radius: 16px; + border: 4px solid var(--background-fill-primary); +} + +.extra-network-pane .extra-network-cards::-webkit-scrollbar-button, +.extra-network-pane .tree-list-container::-webkit-scrollbar-button { + display: none; } .extra-network-tree .tree-list .tree-list-controls { @@ -1244,17 +1277,15 @@ body.resizing .resize-handle { background-color: transparent; } -/* Directory
<ul> visibility based on expanded attribute. */
-.extra-network-tree .tree-list-content[expanded=false]+.tree-list--subgroup {
+/* Directory <ul>
                  visibility based on data-expanded attribute. */ +.extra-network-tree .tree-list-content+.tree-list--subgroup { height: 0; - overflow: hidden; visibility: hidden; opacity: 0; } -.extra-network-tree .tree-list-content[expanded=true]+.tree-list--subgroup { +.extra-network-tree .tree-list-content[data-expanded]+.tree-list--subgroup { height: auto; - overflow: visible; visibility: visible; opacity: 1; } @@ -1307,7 +1338,7 @@ body.resizing .resize-handle { background-color: var(--neutral-800); } -.dark .extra-network-tree div.tree-list-content[selected] { +.dark .extra-network-tree div.tree-list-content[data-selected] { background-color: var(--neutral-700); } @@ -1317,20 +1348,20 @@ body.resizing .resize-handle { background-color: var(--neutral-200); } -.extra-network-tree div.tree-list-content[selected] { +.extra-network-tree div.tree-list-content[data-selected] { background-color: var(--neutral-300); } /* ==== CHEVRON ICON ACTIONS ==== */ /* Define the animation for the arrow when it is clicked. */ -.extra-network-tree .tree-list-content-dir[expanded=false] .tree-list-item-action-chevron { +.extra-network-tree .tree-list-content-dir .tree-list-item-action-chevron { -ms-transform: rotate(135deg); -webkit-transform: rotate(135deg); transform: rotate(135deg); transition: transform 0.2s; } -.extra-network-tree .tree-list-content-dir[expanded=true] .tree-list-item-action-chevron { +.extra-network-tree .tree-list-content-dir[data-expanded] .tree-list-item-action-chevron { -ms-transform: rotate(225deg); -webkit-transform: rotate(225deg); transform: rotate(225deg); From ccee26b0653b4f6778c107d68df52da27446abd2 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 16 Jan 2024 14:54:07 -0500 Subject: [PATCH 0598/1174] fix bugs --- javascript/extraNetworks.js | 28 ++++++++----------- modules/ui_extra_networks.py | 35 ++++++++++++------------ modules/ui_extra_networks_checkpoints.py | 3 +- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index a3f003bf475..caaa3fae0d2 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -31,11 +31,12 @@ function setupExtraNetworksForTab(tabname) { var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs'); this_tab.classList.add('extra-networks'); this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) { - var extra_networks_tabname = elem.id; - var search = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_search"); - var sort_mode = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort"); - var sort_dir = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort_dir"); - var refresh = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_refresh"); + // tabname_full = {tabname}_{extra_networks_tabname} + var tabname_full = elem.id; + var search = gradioApp().querySelector("#" + tabname_full + "_extra_search"); + var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort"); + var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir"); + var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh"); // If any of the buttons above don't exist, we want to skip this iteration of the loop. 
if (!search || !sort_mode || !sort_dir || !refresh) { @@ -44,16 +45,13 @@ function setupExtraNetworksForTab(tabname) { var applyFilter = function() { var searchTerm = search.value.toLowerCase(); - gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); - var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase(); }).join(" "); var visible = text.indexOf(searchTerm) != -1; - if (searchOnly && searchTerm.length < 4) { visible = false; } @@ -66,7 +64,6 @@ function setupExtraNetworksForTab(tabname) { var applySort = function() { var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); - var reverse = sort_dir.dataset.sortdir == "Descending"; var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); @@ -104,9 +101,8 @@ function setupExtraNetworksForTab(tabname) { search.addEventListener("input", applyFilter); applySort(); applyFilter(); - - extraNetworksApplySort[extra_networks_tabname] = applySort; - extraNetworksApplyFilter[extra_networks_tabname] = applyFilter; + extraNetworksApplySort[tabname_full] = applySort; + extraNetworksApplyFilter[tabname_full] = applyFilter; }); registerPrompt(tabname, tabname + "_prompt"); @@ -147,12 +143,12 @@ function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); } -function applyExtraNetworkFilter(tabname) { - setTimeout(extraNetworksApplyFilter[tabname], 1); +function applyExtraNetworkFilter(tabname_full) { + setTimeout(extraNetworksApplyFilter[tabname_full], 1); } -function applyExtraNetworkSort(tabname) { - setTimeout(extraNetworksApplySort[tabname], 1); +function applyExtraNetworkSort(tabname_full) { + setTimeout(extraNetworksApplySort[tabname_full], 1); } var extraNetworksApplyFilter = {}; diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a03207b2bc0..55cd1da2772 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -237,7 +237,7 @@ def create_item_html( "tabname": tabname, "prompt": item["prompt"], "neg_prompt": item.get("negative_prompt", ""), - "allow_neg": "true" if self.allow_negative_prompt else "false" + "allow_neg": str(self.allow_negative_prompt).lower(), } ) onclick = html.escape(onclick) @@ -291,7 +291,7 @@ def create_item_html( search_terms_html += search_term_template.format( **{ "style": "display: none;", - "class": "search_terms" + (" search_only" if search_only else ""), + "class": f"search_terms{' search_only' if search_only else ''}", "search_term": search_term, } ) @@ -307,7 +307,7 @@ def create_item_html( "metadata_button": btn_metadata, "name": html.escape(item["name"]), "prompt": item.get("prompt", None), - "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', + "save_card_preview": html.escape(f"return saveCardPreview(event, '{tabname}', '{item['local_preview']}');"), "search_only": " search_only" if search_only else "", "search_terms": search_terms_html, "sort_keys": sort_keys, @@ -369,7 +369,7 @@ def create_tree_dir_item_html( ul = f"" return ( "
<li class='tree-list-item tree-list-item--has-subitem' data-tree-entry-type='dir'>"
-            f"{btn + ul}"
+            f"{btn}{ul}"
             "</li>"
         )

@@ -561,7 +561,7 @@ def find_preview(self, path):
         """
         Find a preview PNG for a given path (without extension) and call link_preview on it.
         """
-        potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])
+        potential_files = sum([[f"{path}.{ext}", f"{path}.preview.{ext}"] for ext in allowed_preview_extensions()], [])

         for file in potential_files:
             if self.lister.exists(file):

@@ -642,7 +642,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):

     related_tabs = []

-    button_refresh = gr.Button("Refresh", elem_id=tabname+"_extra_refresh_internal", visible=False)
+    button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_extra_refresh_internal", visible=False)

     for page in ui.stored_extra_pages:
         with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab:

@@ -652,24 +652,25 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
         elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html"
         page_elem = gr.HTML('Loading...', elem_id=elem_id)
         ui.pages.append(page_elem)
-        page_elem.change(
-            fn=lambda: None,
-            _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.extra_networks_tabname}_extra_search); return []}}",
-            inputs=[],
-            outputs=[],
-        )
         editor = page.create_user_metadata_editor(ui, tabname)
         editor.create_ui()
         ui.user_metadata_editors.append(editor)
         related_tabs.append(tab)

-    ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
-    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+    ui.button_save_preview = gr.Button('Save preview', elem_id=f"{tabname}_save_preview", visible=False)
+    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=f"{tabname}_preview_filename", visible=False)

     for tab in unrelated_tabs:
-        tab.select(fn=None, _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=[], show_progress=False)
+        tab.select(fn=None, _js=f"function(){{extraNetworksUnrelatedTabSelected('{tabname}');}}", inputs=[], outputs=[], show_progress=False)
+
+    for page, tab in zip(ui.stored_extra_pages, related_tabs):
+        jscode = (
+            "function(){{"
+            f"extraNetworksTabSelected('{tabname}', '{tabname}_{page.extra_networks_tabname}_prompts', {str(page.allow_prompt).lower()}, {str(page.allow_negative_prompt).lower()});"
+            f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');"
+            "}}"
+        )
+        tab.select(fn=None, _js=jscode, inputs=[], outputs=[], show_progress=False)

     def create_html():
         ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]

diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index e7976ba1260..a8c3367198f 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -2,7 +2,6 @@
 import os

 from modules import shared, ui_extra_networks, sd_models
-from modules.ui_extra_networks import quote_js
 from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor

@@ -31,7 +30,7 @@ def create_item(self, name, index=None, enable_filter=True):
             "preview": self.find_preview(path),
             "description": self.find_description(path),
             "search_terms": search_terms,
-            "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
+            "onclick": html.escape(f"return selectCheckpoint('{name}');"),
             "local_preview":
f"{path}.{shared.opts.samples_format}", "metadata": checkpoint.metadata, "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, From 0b83f4c26376c87505769a3687cb06e321b413a1 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 13 Jan 2024 18:38:05 +0900 Subject: [PATCH 0599/1174] reuse seed from infotexts --- modules/processing_scripts/seed.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 2d3cbb97f06..7b727d17cb1 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -6,6 +6,7 @@ from modules.infotext_utils import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton +from modules import infotext_utils class ScriptSeed(scripts.ScriptBuiltinUI): @@ -77,7 +78,6 @@ def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_f p.seed_resize_from_h = seed_resize_from_h - def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength @@ -85,21 +85,14 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def copy_seed(gen_info_string: str, index): res = -1 - try: gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError: + infotext = gen_info.get('infotexts')[index] + gen_parameters = infotext_utils.parse_generation_parameters(infotext) + res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1)) + except Exception: if gen_info_string: - errors.report(f"Error parsing JSON generation info: {gen_info_string}") + errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True) return [res, gr.update()] From 6acd8e28fceccbbea0c09c91aed0eff5a3504315 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 14 Jan 2024 01:58:01 +0900 Subject: [PATCH 0600/1174] save_files info base on infotexts --- modules/ui_common.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index f17259c2956..6a4465e9358 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -40,8 +40,9 @@ def save_files(js_data, images, do_make_zip, index): import csv filenames = [] fullfns = [] + parsed_infotexts = [] - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + # quick dictionary to class object conversion. 
Its necessary due apply_filename_pattern requiring it class MyObject: def __init__(self, d=None): if d is not None: @@ -49,16 +50,14 @@ def __init__(self, d=None): setattr(self, key, value) data = json.loads(js_data) - p = MyObject(data) + path = shared.opts.outdir_save save_to_dirs = shared.opts.use_save_to_dirs_for_ui extension: str = shared.opts.samples_format start_index = 0 - only_one = False if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - only_one = True images = [images[index]] start_index = index @@ -74,10 +73,12 @@ def __init__(self, d=None): image = image_from_url_text(filedata) is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) p.batch_index = image_index-1 - fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index]) + parsed_infotexts.append(parameters) + fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) @@ -86,12 +87,12 @@ def __init__(self, d=None): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt']]) # Make Zip if do_make_zip: - zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0] - namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True) + p.all_seeds = [parameters['Seed'] for parameters in parsed_infotexts] + namegen = modules.images.FilenameGenerator(p, parsed_infotexts[0]['Seed'], parsed_infotexts[0]['Prompt'], image, True) zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]") zip_filepath = os.path.join(path, f"{zip_filename}.zip") From d224fed0ce16e6f1507b0da29e7495ab2b0035e4 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:16:07 +0900 Subject: [PATCH 0601/1174] parse_generation_parameters skip_fields --- modules/infotext_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 9a02cdf2410..1049c6c3c59 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -230,7 +230,7 @@ def restore_old_hires_fix_params(res): res['Hires resize-2'] = height -def parse_generation_parameters(x: str): +def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): """parses generation parameters string, the one you see in text field under the picture in UI: ``` girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), 
detailed, intricate @@ -240,6 +240,8 @@ def parse_generation_parameters(x: str): returns a dict with field values """ + if skip_fields is None: + skip_fields = shared.opts.infotext_skip_pasting res = {} @@ -356,8 +358,8 @@ def parse_generation_parameters(x: str): infotext_versions.backcompat(res) - skip = set(shared.opts.infotext_skip_pasting) - res = {k: v for k, v in res.items() if k not in skip} + for key in skip_fields: + res.pop(key, None) return res From 45a51c07e2ce5a36dc3cddde9a666a4c7b61af82 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:21:58 +0900 Subject: [PATCH 0602/1174] parse_generation_parameters with no skip_fields --- modules/processing_scripts/seed.py | 2 +- modules/ui_common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 7b727d17cb1..7a4c0159831 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -88,7 +88,7 @@ def copy_seed(gen_info_string: str, index): try: gen_info = json.loads(gen_info_string) infotext = gen_info.get('infotexts')[index] - gen_parameters = infotext_utils.parse_generation_parameters(infotext) + gen_parameters = infotext_utils.parse_generation_parameters(infotext, []) res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1)) except Exception: if gen_info_string: diff --git a/modules/ui_common.py b/modules/ui_common.py index 6a4465e9358..6d7f3a67574 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -76,7 +76,7 @@ def __init__(self, d=None): p.batch_index = image_index-1 - parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index]) + parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index], []) parsed_infotexts.append(parameters) fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) From 6916de5c0bd8df3835d450caa3327d1924db081c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:16:07 +0900 Subject: [PATCH 0603/1174] parse_generation_parameters skip_fields --- modules/infotext_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 9a02cdf2410..1049c6c3c59 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -230,7 +230,7 @@ def restore_old_hires_fix_params(res): res['Hires resize-2'] = height -def parse_generation_parameters(x: str): +def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): """parses generation parameters string, the one you see in text field under the picture in UI: ``` girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate @@ -240,6 +240,8 @@ def parse_generation_parameters(x: str): returns a dict with field values """ + if skip_fields is None: + skip_fields = shared.opts.infotext_skip_pasting res = {} @@ -356,8 +358,8 @@ def parse_generation_parameters(x: str): infotext_versions.backcompat(res) - skip = set(shared.opts.infotext_skip_pasting) - res = {k: v for k, v in res.items() if k not in skip} + for key in skip_fields: + res.pop(key, None) return res From 
e1dfd452c0447e729a7341c434b4aab6063aa654 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:27:29 +0900 Subject: [PATCH 0604/1174] parse_generation_parameters with no skip_fields --- modules/txt2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/txt2img.py b/modules/txt2img.py index e617cb1c576..92a160d61da 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -70,7 +70,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] p.firstpass_image = infotext_utils.image_from_url_text(image_info) - parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index]) + parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], []) p.seed = parameters.get('Seed', -1) p.subseed = parameters.get('Variation seed', -1) From 2cf23099ebb81832a27c8016f14062885f5a9c98 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 18 Jan 2024 04:44:21 +0900 Subject: [PATCH 0605/1174] fix console total progress bar when using txt2img_upscale add p.txt2img_upscale as indicator --- modules/processing.py | 7 +++++-- modules/txt2img.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index dcc807fe38d..547ab2f86a0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1227,8 +1227,11 @@ def init(self, all_prompts, all_seeds, all_subseeds): if not state.processing_has_refined_job_count: if state.job_count == -1: state.job_count = self.n_iter - - shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count) + if getattr(self, 'txt2img_upscale', False): + total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count + else: + total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count + shared.total_tqdm.updateTotal(total_steps) state.job_count = state.job_count * 2 state.processing_has_refined_job_count = True diff --git a/modules/txt2img.py b/modules/txt2img.py index 92a160d61da..4efcb4c3d71 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -64,6 +64,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g p.enable_hr = True p.batch_size = 1 p.n_iter = 1 + p.txt2img_upscale = True geninfo = json.loads(generation_info) From f25c81a74462554890ac7327a30629b332db1084 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Wed, 17 Jan 2024 22:38:51 -0500 Subject: [PATCH 0606/1174] Fix embeddings add/remove to/from prompt on click bugs. --- javascript/extraNetworks.js | 13 +++---------- modules/ui_extra_networks.py | 2 +- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index caaa3fae0d2..1e2786ab0d2 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -169,8 +169,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { var m = text.match(isNeg ? 
re_extranet_neg : re_extranet); var replaced = false; var newTextareaText; + var extraTextBeforeNet = opts.extra_networks_add_text_separator; if (m) { - var extraTextBeforeNet = opts.extra_networks_add_text_separator; var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; @@ -183,7 +183,6 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { } return found; }); - if (foundAtPosition >= 0) { if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); @@ -193,13 +192,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { } } } else { - newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) { - if (found == text) { - replaced = true; - return ""; - } - return found; - }); + newTextareaText = textarea.value.replaceAll(new RegExp(`((?:${extraTextBeforeNet})?${text})`, "g"), ""); + replaced = (newTextareaText != textarea.value); } if (replaced) { @@ -211,7 +205,6 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { } function updatePromptArea(text, textArea, isNeg) { - if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) { textArea.value = textArea.value + opts.extra_networks_add_text_separator + text; } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 55cd1da2772..5dd4e443f8e 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -236,7 +236,7 @@ def create_item_html( **{ "tabname": tabname, "prompt": item["prompt"], - "neg_prompt": item.get("negative_prompt", ""), + "neg_prompt": item.get("negative_prompt", "''"), "allow_neg": str(self.allow_negative_prompt).lower(), } ) From 0181c1f76b97162c42401f1e6286ae73d8aa6033 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 19 Jan 2024 00:14:03 +0800 Subject: [PATCH 0607/1174] Fix nested manual cast --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/devices.py b/modules/devices.py index 0321d12c6ab..3702862979c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -165,6 +165,8 @@ def forward_wrapper(self, *args, **kwargs): @contextlib.contextmanager def manual_cast(target_dtype): for module_type in patch_module_list: + if hasattr(module_type, "org_forward"): + continue org_forward = module_type.forward if module_type == torch.nn.MultiheadAttention and has_xpu(): module_type.forward = manual_cast_forward(torch.float32) @@ -175,7 +177,9 @@ def manual_cast(target_dtype): yield None finally: for module_type in patch_module_list: - module_type.forward = module_type.org_forward + if hasattr(module_type, "org_forward"): + module_type.forward = module_type.org_forward + delattr(module_type, "org_forward") def autocast(disable=False): From 50e444fa1d1c779b0c9c915e44ed2d78b587f277 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Thu, 18 Jan 2024 12:13:09 -0500 Subject: [PATCH 0608/1174] Fix missing important style. 
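Why the `!important` is needed: the extra networks filter hides cards by toggling a `hidden` class from JavaScript (see the `applyFilter` change in the next commit), while Gradio components often set `display` through inline styles, and an inline style outranks a plain class rule in the cascade. A minimal sketch of the interaction (the element and its inline style are illustrative, not taken from the webui DOM):

```javascript
// A card that something upstream has explicitly made visible:
const card = document.createElement("div");
card.className = "card";
card.style.display = "block"; // inline style: beats any normal class rule

card.classList.add("hidden");
// With `.hidden { display: none; }` the card stays visible (inline wins);
// with `.hidden { display: none !important; }` it is actually hidden.
```

Marking the rule `!important` lets the helper class win regardless of whatever other styles a component carries.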
--- style.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/style.css b/style.css index 1090e4368ab..57c52354a0b 100644 --- a/style.css +++ b/style.css @@ -28,7 +28,7 @@ div.gradio-container{ } .hidden{ - display: none; + display: none !important; } .compact{ From 69f4f148dce0868b748b700f96942d4036e848c9 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Thu, 18 Jan 2024 12:13:33 -0500 Subject: [PATCH 0609/1174] Fix various bugs including refresh bug. --- javascript/extraNetworks.js | 7 +++++-- modules/ui_extra_networks.py | 31 ++++++++++++++++--------------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 1e2786ab0d2..3029afec847 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -55,8 +55,11 @@ function setupExtraNetworksForTab(tabname) { if (searchOnly && searchTerm.length < 4) { visible = false; } - - elem.style.display = visible ? "" : "none"; + if (visible) { + elem.classList.remove("hidden"); + } else { + elem.classList.add("hidden"); + } }); applySort(); diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 5dd4e443f8e..656e7f181e4 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -216,22 +216,17 @@ def create_item_html( Can be empty if the item is not meant to be shown. If no template is passed: A dictionary containing the generated item's attributes. """ - metadata = item.get("metadata") - if metadata: - self.metadata[item["name"]] = metadata - - if "user_metadata" not in item: - self.read_user_metadata(item) - preview = item.get("preview", None) - height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' - width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + style_height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' + style_width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + style_font_size = f"font-size: {shared.opts.extra_networks_card_text_scale*100}%;" + card_style = style_height + style_width + style_font_size background_image = f'' if preview else '' onclick = item.get("onclick", None) if onclick is None: # Don't quote prompt/neg_prompt since they are stored as js strings already. 
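+            # (The quoting change below matters: '{allow_neg}' becomes a bare
+            # {allow_neg}, which, combined with the earlier switch to
+            # "allow_neg": str(self.allow_negative_prompt).lower(), passes
+            # cardClicked() a real JS boolean rather than the string 'true'/'false'.)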
- onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, '{allow_neg}');" + onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});" onclick = onclick_js_tpl.format( **{ "tabname": tabname, @@ -286,11 +281,10 @@ def create_item_html( ).strip() search_terms_html = "" - search_term_template = "{search_term}" + search_term_template = "" for search_term in item.get("search_terms", []): search_terms_html += search_term_template.format( **{ - "style": "display: none;", "class": f"search_terms{' search_only' if search_only else ''}", "search_term": search_term, } @@ -301,7 +295,7 @@ def create_item_html( "background_image": background_image, "card_clicked": onclick, "copy_path_button": btn_copy_path, - "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), + "description": (item.get("description", "") or "" if shared.opts.extra_networks_card_show_desc else ""), "edit_button": btn_edit_item, "local_preview": quote_js(item["local_preview"]), "metadata_button": btn_metadata, @@ -311,7 +305,7 @@ def create_item_html( "search_only": " search_only" if search_only else "", "search_terms": search_terms_html, "sort_keys": sort_keys, - "style": f"display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%", + "style": card_style, "tabname": tabname, "extra_networks_tabname": self.extra_networks_tabname, } @@ -500,7 +494,6 @@ def create_card_view_html(self, tabname: str) -> str: HTML formatted string. """ res = "" - self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): res += self.create_item_html(tabname, item, self.card_tpl) @@ -524,6 +517,14 @@ def create_html(self, tabname): self.lister.reset() self.metadata = {} self.items = {x["name"]: x for x in self.list_items()} + # Populate the instance metadata for each item. 
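+        # (This used to happen per-item inside create_item_html; doing it once
+        # here keeps that method free of side effects -- see the corresponding
+        # lines removed from it earlier in this same commit.)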
+ for item in self.items.values(): + metadata = item.get("metadata") + if metadata: + self.metadata[item["name"]] = metadata + + if "user_metadata" not in item: + self.read_user_metadata(item) return self.pane_tpl.format( **{ From a97147bc8a43ade7c18bb755f0cfac111fc1a619 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:10:02 +0100 Subject: [PATCH 0610/1174] Add support for DAT upscaler models --- modules/dat_model.py | 81 +++++++++++++++++++++++++++++++++++++++ modules/shared_items.py | 5 +++ modules/shared_options.py | 3 ++ 3 files changed, 89 insertions(+) create mode 100644 modules/dat_model.py diff --git a/modules/dat_model.py b/modules/dat_model.py new file mode 100644 index 00000000000..8637351c500 --- /dev/null +++ b/modules/dat_model.py @@ -0,0 +1,81 @@ +import os +import sys + +from modules import modelloader, devices +from modules.shared import cmd_opts, opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model + +from icecream import ic + +class UpscalerDAT(Upscaler): + def __init__(self, user_path): + self.name = "DAT" + self.user_path = user_path + self.scalers = [] + super().__init__() + + for file in self.find_models(ext_filter=[".pt", ".pth"]): + name = modelloader.friendly_name(file) + scaler_data = UpscalerData(name, file, upscaler=self, scale=None) + self.scalers.append(scaler_data) + + for model in get_dat_models(self): + if model.name in opts.dat_enabled_models: + self.scalers.append(model) + + def do_upscale(self, img, selected_model): + try: + info = self.load_model(selected_model) + except Exception as e: + errors.report(f"Unable to load DAT model {path}", exc_info=True) + return img + + model_descriptor = modelloader.load_spandrel_model( + info.local_data_path, + device=self.device, + prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + expected_architecture="DAT", + ) + return upscale_with_model( + model_descriptor, + img, + tile_size=opts.DAT_tile, + tile_overlap=opts.DAT_tile_overlap, + ) + + def load_model(self, path): + for scaler in self.scalers: + if scaler.data_path == path: + if scaler.local_data_path.startswith("http"): + scaler.local_data_path = modelloader.load_file_from_url( + scaler.data_path, + model_dir=self.model_download_path, + ) + if not os.path.exists(scaler.local_data_path): + raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}") + return scaler + raise ValueError(f"Unable to find model info: {path}") + + +def get_dat_models(scaler): + return [ + UpscalerData( + name="DAT x2", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth", + scale=2, + upscaler=scaler, + ), + UpscalerData( + name="DAT x3", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth", + scale=3, + upscaler=scaler, + ), + UpscalerData( + name="DAT x4", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth", + scale=4, + upscaler=scaler, + ), + ] diff --git a/modules/shared_items.py b/modules/shared_items.py index 13fb2814f2b..88f636452c7 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -8,6 +8,11 @@ def realesrgan_models_names(): return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] +def dat_models_names(): + import modules.dat_model + return [x.name for x in modules.dat_model.get_dat_models(None)] + + def postprocessing_scripts(): import modules.scripts diff --git a/modules/shared_options.py b/modules/shared_options.py index 63488f4e700..48a206ce941 100644 
--- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -97,6 +97,9 @@ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), + "dat_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}), + "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), + "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), })) From 2e7efe47b6c78a59f9f3d64c1d30f54751a31a56 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:39:14 +0100 Subject: [PATCH 0611/1174] Minor cleanup --- modules/dat_model.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/modules/dat_model.py b/modules/dat_model.py index 8637351c500..495d5f4937d 100644 --- a/modules/dat_model.py +++ b/modules/dat_model.py @@ -1,12 +1,10 @@ import os -import sys -from modules import modelloader, devices +from modules import modelloader, errors from modules.shared import cmd_opts, opts from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model -from icecream import ic class UpscalerDAT(Upscaler): def __init__(self, user_path): @@ -24,13 +22,13 @@ def __init__(self, user_path): if model.name in opts.dat_enabled_models: self.scalers.append(model) - def do_upscale(self, img, selected_model): + def do_upscale(self, img, path): try: - info = self.load_model(selected_model) - except Exception as e: + info = self.load_model(path) + except Exception: errors.report(f"Unable to load DAT model {path}", exc_info=True) return img - + model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, From 1ddb886a804dc69f542ebc71bdd7baec48f677b6 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:48:46 +0100 Subject: [PATCH 0612/1174] Fix wrong options value --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 48a206ce941..74a2a67f933 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -97,7 +97,7 @@ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), - "dat_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], 
"Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}), + "dat_enabled_models": OptionInfo(["DAT x2", "DAT x3", "DAT x4"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}), "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), From d31dc7a73996ee611a72ef641237606fc9eddca5 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jan 2024 00:40:03 +0400 Subject: [PATCH 0613/1174] fix extras big batch crashes --- modules/postprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 7449b0dc51c..f14882321ec 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -62,8 +62,6 @@ def get_images(extras_mode, image, image_folder, input_dir): else: image_data = image_placeholder - shared.state.assign_current_image(image_data) - parameters, existing_pnginfo = images.read_info_from_image(image_data) if parameters: existing_pnginfo["parameters"] = parameters @@ -92,6 +90,8 @@ def get_images(extras_mode, image, image_folder, input_dir): pp.image.info = existing_pnginfo pp.image.info["postprocessing"] = infotext + shared.state.assign_current_image(pp.image) + if save_output: fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix) From 4dde96109bb9621bb5608f7f792dc569d692ecf5 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:38:34 -0800 Subject: [PATCH 0614/1174] Update zoom.js --- extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index dd6028c1e25..df60c1a1775 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -218,8 +218,8 @@ onUiLoaded(async() => { canvas_hotkey_fullscreen: "KeyS", canvas_hotkey_move: "KeyF", canvas_hotkey_overlap: "KeyO", - canvas_hotkey_shrink_brush: "BracketLeft", - canvas_hotkey_grow_brush: "BracketRight", + canvas_hotkey_shrink_brush: "KeyQ", + canvas_hotkey_grow_brush: "KeyW", canvas_disabled_functions: [], canvas_show_tooltip: true, canvas_auto_expand: true, From bcdcc8be7d5fb923d64b7eb8a9a831f26f179d97 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:39:07 -0800 Subject: [PATCH 0615/1174] Update hotkey_config.py --- .../canvas-zoom-and-pan/scripts/hotkey_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 9d9accce8cf..6d9e34d8b0e 100644 --- 
a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -4,8 +4,8 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), { "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), - "canvas_hotkey_shrink_brush": shared.OptionInfo("[", "Shrink the brush size"), - "canvas_hotkey_grow_brush": shared.OptionInfo("]", "Enlarge the brush size"), + "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"), + "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"), "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "), "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), From 30f31d23c6d212389aef035137c381dd79d44a26 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:39:33 -0800 Subject: [PATCH 0616/1174] Update hotkey_config.py --- extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 6d9e34d8b0e..89b7c31f22d 100644 --- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -5,7 +5,7 @@ "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"), - "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), + "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"), "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "), "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), From 56676ff923497901d56fcbca1ac549095e71d72d Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jan 2024 11:49:05 +0400 Subject: [PATCH 0617/1174] fix tab indexes reset after restart ui --- modules/ui.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui.py 
b/modules/ui.py index a716a04055d..ebd33a85616 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -532,7 +532,7 @@ def add_copy_image_controls(tab_name, elem): if category == "image": with gr.Tabs(elem_id="mode_img2img"): - img2img_selected_tab = gr.State(0) + img2img_selected_tab = gr.Number(value=0, visible=False) with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) @@ -613,7 +613,7 @@ def copy_image(img): elif category == "dimensions": with FormRow(): with gr.Column(elem_id="img2img_column_size", scale=4): - selected_scale_tab = gr.State(value=0) + selected_scale_tab = gr.Number(value=0, visible=False) with gr.Tabs(): with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 7a132ac22ed..ff22a178075 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -5,7 +5,7 @@ def create_ui(): dummy_component = gr.Label(visible=False) - tab_index = gr.State(value=0) + tab_index = gr.Number(value=0, visible=False) with gr.Row(equal_height=False, variant='compact'): with gr.Column(variant='compact'): From 81126027f5226e7ee58e1a99194eb9ec7b8ec6e7 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 20 Jan 2024 16:31:12 +0800 Subject: [PATCH 0618/1174] Avoid early disable --- modules/devices.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/devices.py b/modules/devices.py index 3702862979c..3bde169904b 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -164,9 +164,11 @@ def forward_wrapper(self, *args, **kwargs): @contextlib.contextmanager def manual_cast(target_dtype): + applied = False for module_type in patch_module_list: if hasattr(module_type, "org_forward"): continue + applied = True org_forward = module_type.forward if module_type == torch.nn.MultiheadAttention and has_xpu(): module_type.forward = manual_cast_forward(torch.float32) @@ -176,6 +178,8 @@ def manual_cast(target_dtype): try: yield None finally: + if not applied: + return for module_type in patch_module_list: if hasattr(module_type, "org_forward"): module_type.forward = module_type.org_forward From 4a66d2fb228584bb38dc22db6a3e657561834c7a Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 20 Jan 2024 16:33:59 +0800 Subject: [PATCH 0619/1174] Avoid exceptions to be silenced --- modules/devices.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 3bde169904b..dfffaf24fd0 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -178,12 +178,11 @@ def manual_cast(target_dtype): try: yield None finally: - if not applied: - return - for module_type in patch_module_list: - if hasattr(module_type, "org_forward"): - module_type.forward = module_type.org_forward - delattr(module_type, "org_forward") + if applied: + for module_type in patch_module_list: + if hasattr(module_type, "org_forward"): + module_type.forward = module_type.org_forward + delattr(module_type, "org_forward") def autocast(disable=False): From ed383eb5a0db78d0da60420bb464943e4cc9b847 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jan 2024 15:37:49 +0400 Subject: [PATCH 0620/1174] keep postprocessing 
upscale selected tab after restart --- scripts/postprocessing_upscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index a57f9d4a4b6..e269682d0f3 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -15,7 +15,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): order = 1000 def ui(self): - selected_tab = gr.State(value=0) + selected_tab = gr.Number(value=0, visible=False) with gr.Column(): with FormRow(): From 2310cd66e5381fbe6b966894381c6ee7b762898f Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sat, 20 Jan 2024 11:43:45 -0500 Subject: [PATCH 0621/1174] Add toggle button for tree view. Use default settings for sortmode and direction. --- html/extra-networks-pane.html | 55 ++++++++++++++++++++-- html/extra-networks-tree.html | 37 --------------- javascript/extraNetworks.js | 20 ++++++-- modules/ui_extra_networks.py | 7 +++ style.css | 86 +++++++++++++++++++++++------------ 5 files changed, 132 insertions(+), 73 deletions(-) diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index bf46ca16305..73dad2ab248 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,8 +1,55 @@
                  -
                  - {tree_html} +
                  + +
                  + +
                  +
                  + +
                  +
                  + +
                  +
                  + +
                  -
                  - {items_html} +
                  +
                  + {tree_html} +
                  +
                  + {items_html} +
                  \ No newline at end of file diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html index 23f6af1058c..39649e86093 100644 --- a/html/extra-networks-tree.html +++ b/html/extra-networks-tree.html @@ -1,41 +1,4 @@
                  -
                  - -
                  - -
                  -
                  - -
                  -
                  - -
                  -
                  {tree}
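For reference, the renamed extraNetworksControl* handlers in the javascript diff below read their sort settings from data attributes that create_html writes into the pane template (see the ui_extra_networks.py hunk further down). A minimal standalone sketch of that data_sortmode normalization, for illustration only and not part of the patch:

    for field in ["Path", "Name", "Date Created", "Date Modified"]:
        mode = field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip()
        print(field, "->", mode)
    # Path -> path, Name -> name, Date Created -> date_created, Date Modified -> date_modified

These are the values matched by the [data-sortmode=...] selectors in the style.css hunk of this patch.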
                  diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 3029afec847..ce788328ce5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -348,7 +348,7 @@ function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) { } } -function extraNetworksTreeSortOnClick(event, tabname, extra_networks_tabname) { +function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Sort Mode button. * @@ -382,7 +382,7 @@ function extraNetworksTreeSortOnClick(event, tabname, extra_networks_tabname) { applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeSortDirOnClick(event, tabname, extra_networks_tabname) { +function extraNetworksControlSortDirOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Sort Direction button. * @@ -403,7 +403,21 @@ function extraNetworksTreeSortDirOnClick(event, tabname, extra_networks_tabname) applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeRefreshOnClick(event, tabname, extra_networks_tabname) { +function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Tree View button. + * + * Toggles the tree view in the extra networks pane. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_tree").classList.toggle("hidden"); + event.currentTarget.classList.toggle("extra-network-control--enabled"); +} + +function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Refresh Page button. 
* diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 656e7f181e4..4c8a407449e 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -526,10 +526,17 @@ def create_html(self, tabname): if "user_metadata" not in item: self.read_user_metadata(item) + data_sortdir = shared.opts.extra_networks_card_order + data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip() + data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}" + return self.pane_tpl.format( **{ "tabname": tabname, "extra_networks_tabname": self.extra_networks_tabname, + "data_sortmode": data_sortmode, + "data_sortkey": data_sortkey, + "data_sortdir": data_sortdir, "tree_html": self.create_tree_view_html(tabname), "items_html": self.create_card_view_html(tabname), } diff --git a/style.css b/style.css index 57c52354a0b..f3fd1571bb1 100644 --- a/style.css +++ b/style.css @@ -1178,6 +1178,13 @@ body.resizing .resize-handle { height: calc(100vh - 24rem); resize: vertical; min-height: 52rem; + flex-direction: column; +} + +.extra-network-pane .extra-network-pane-content { + display: flex; + flex: 1; + flex-direction: row; } .extra-network-pane .extra-network-tree { @@ -1234,7 +1241,7 @@ body.resizing .resize-handle { display: none; } -.extra-network-tree .tree-list .tree-list-controls { +.extra-network-pane .extra-network-control { position: relative; display: grid; width: 100%; @@ -1248,8 +1255,7 @@ body.resizing .resize-handle { border: none; transition: background 33.333ms linear; grid-template-rows: min-content; - grid-template-areas: "tree-list-controls-col-0 tree-list-controls-col-1 tree-list-controls-col-2 tree-list-controls-col-3"; - grid-template-columns: minmax(0, auto) min-content min-content min-content; + grid-template-columns: minmax(0, auto) repeat(4, min-content); grid-gap: 0.1rem; align-items: start; } @@ -1342,16 +1348,16 @@ body.resizing .resize-handle { background-color: var(--neutral-700); } +.extra-network-tree div.tree-list-content[data-selected] { + background-color: var(--neutral-300); +} + .extra-network-tree div.tree-list-content:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-200); } -.extra-network-tree div.tree-list-content[data-selected] { - background-color: var(--neutral-300); -} - /* ==== CHEVRON ICON ACTIONS ==== */ /* Define the animation for the arrow when it is clicked. 
*/ .extra-network-tree .tree-list-content-dir .tree-list-item-action-chevron { @@ -1378,7 +1384,7 @@ body.resizing .resize-handle { /* ==== SEARCH INPUT ACTIONS ==== */ /* Add icon to left side of */ -.extra-network-tree .tree-list-controls .tree-list-search::before { +.extra-network-pane .extra-network-control .extra-network-control--search::before { content: "🔎︎"; position: absolute; margin: 0.5rem; @@ -1386,14 +1392,12 @@ body.resizing .resize-handle { color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-controls .tree-list-search { +.extra-network-pane .extra-network-control .extra-network-control--search { display: inline-flex; - grid-area: tree-list-controls-col-0; position: relative; - margin: 0.5rem; } -.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text { +.extra-network-pane .extra-network-control .extra-network-control--search .extra-network-control--search-text { border: 1px solid var(--button-secondary-border-color); border-radius: 0.5rem; color: var(--button-secondary-text-color); @@ -1404,7 +1408,7 @@ body.resizing .resize-handle { } /* clear button (x on right side) styling */ -.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text::-webkit-search-cancel-button { +.extra-network-pane .extra-network-control .extra-network-control--search .extra-network-control--search-text::-webkit-search-cancel-button { -webkit-appearance: none; appearance: none; cursor: pointer; @@ -1418,8 +1422,7 @@ body.resizing .resize-handle { } /* ==== SORT ICON ACTIONS ==== */ -.extra-network-tree .tree-list-controls .tree-list-sort { - grid-area: tree-list-controls-col-1; +.extra-network-pane .extra-network-control .extra-network-control--sort { padding: 0.25rem; display: inline-flex; cursor: pointer; @@ -1427,7 +1430,7 @@ body.resizing .resize-handle { align-self: center; } -.extra-network-tree .tree-list-controls .tree-list-sort .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort .extra-network-control--sort-icon { height: 1.5rem; width: 1.5rem; mask-repeat: no-repeat; @@ -1436,25 +1439,24 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-sort[data-sortmode="path"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="path"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort[data-sortmode="name"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="name"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort[data-sortmode="date_created"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="date_created"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort[data-sortmode="date_modified"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="date_modified"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } /* ==== SORT DIRECTION ICON ACTIONS ==== */ -.extra-network-tree .tree-list-controls .tree-list-sort-dir { - grid-area: tree-list-controls-col-2; +.extra-network-pane .extra-network-control .extra-network-control--sort-dir { padding: 0.25rem; display: inline-flex; cursor: 
pointer; @@ -1462,7 +1464,7 @@ body.resizing .resize-handle { align-self: center; } -.extra-network-tree .tree-list-controls .tree-list-sort-dir .tree-list-sort-dir-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort-dir .extra-network-control--sort-dir-icon { height: 1.5rem; width: 1.5rem; mask-repeat: no-repeat; @@ -1471,17 +1473,43 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-sort-dir[data-sortdir="Ascending"] .tree-list-sort-dir-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort-dir[data-sortdir="Ascending"] .extra-network-control--sort-dir-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort-dir[data-sortdir="Descending"] .tree-list-sort-dir-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort-dir[data-sortdir="Descending"] .extra-network-control--sort-dir-icon { mask-image: url('data:image/svg+xml,'); } +/* ==== TREE VIEW ICON ACTIONS ==== */ +.extra-network-pane .extra-network-control .extra-network-control--tree-view { + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-pane .extra-network-control .extra-network-control--tree-view .extra-network-control--tree-view-icon { + height: 1.5rem; + width: 1.5rem; + mask-image: url('data:image/svg+xml,'); + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.dark .extra-network-pane .extra-network-control .extra-network-control--enabled { + background-color: var(--neutral-700); +} + +.dark .extra-network-pane .extra-network-control .extra-network-control--enabled { + background-color: var(--neutral-300); +} + /* ==== REFRESH ICON ACTIONS ==== */ -.extra-network-tree .tree-list-controls .tree-list-refresh { - grid-area: tree-list-controls-col-3; +.extra-network-pane .extra-network-control .extra-network-control--refresh { padding: 0.25rem; display: inline-flex; cursor: pointer; @@ -1489,7 +1517,7 @@ body.resizing .resize-handle { align-self: center; } -.extra-network-tree .tree-list-controls .tree-list-refresh .tree-list-refresh-icon { +.extra-network-pane .extra-network-control .extra-network-control--refresh .extra-network-control--refresh-icon { height: 1.5rem; width: 1.5rem; mask-image: url('data:image/svg+xml,'); @@ -1499,7 +1527,7 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-refresh-icon:active { +.extra-network-pane .extra-network-control .extra-network-control--refresh-icon:active { -ms-transform: rotate(180deg); -webkit-transform: rotate(180deg); transform: rotate(180deg); From b67a49441fc420f37c6bef1172a0b1ad5c42f30f Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sat, 20 Jan 2024 13:28:37 -0500 Subject: [PATCH 0622/1174] Add option in settings to enable/disable tree view by default. --- html/extra-networks-pane.html | 4 ++-- modules/shared_options.py | 1 + modules/ui_extra_networks.py | 7 +++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 73dad2ab248..9f5b3ecee62 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -29,7 +29,7 @@
                  @@ -45,7 +45,7 @@
                  -
                  +
                  {tree_html}
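The {tree_view_btn_extra_class} and {tree_view_div_extra_class} placeholders added to the template above are ordinary str.format fields that create_html fills in (see the ui_extra_networks.py hunk below). A rough sketch of the expansion, using a stand-in one-line template rather than the real pane HTML:

    pane_tpl = '<button class="extra-network-control--tree-view {tree_view_btn_extra_class}">'
    print(pane_tpl.format(tree_view_btn_extra_class="extra-network-control--enabled"))
    # <button class="extra-network-control--tree-view extra-network-control--enabled">

With the option disabled, the button field expands to an empty string and the tree div instead receives the "hidden" class.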
diff --git a/modules/shared_options.py index 63488f4e700..e0a6d977c30 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -251,6 +251,7 @@ "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(), + "extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 4c8a407449e..80160b847b6 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -529,6 +529,11 @@ def create_html(self, tabname): data_sortdir = shared.opts.extra_networks_card_order data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip() data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}" + tree_view_btn_extra_class = "" + tree_view_div_extra_class = "hidden" + if shared.opts.extra_networks_tree_view_default_enabled: + tree_view_btn_extra_class = "extra-network-control--enabled" + tree_view_div_extra_class = "" return self.pane_tpl.format( **{ @@ -537,6 +542,8 @@ def create_html(self, tabname): "data_sortmode": data_sortmode, "data_sortkey": data_sortkey, "data_sortdir": data_sortdir, + "tree_view_btn_extra_class": tree_view_btn_extra_class, + "tree_view_div_extra_class": tree_view_div_extra_class, "tree_html": self.create_tree_view_html(tabname), "items_html": self.create_card_view_html(tabname), } From 25e8273d2f6481deca221b29d35093f6d0c9da6a Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 21 Jan 2024 02:21:36 +0900 Subject: [PATCH 0623/1174] re-work multi --styles-file --styles-file changed to an append-type str argument; if --styles-file is [] it defaults to [styles.csv]; --styles-file accepts paths or paths with the wildcard "*"; the first `--styles-file` entry is used as the default styles file path; if the filename is a wildcard, the first matching file is used; if no match is found, a new "styles.csv" is created in the same dir as the first path; when saving a new style, it will be saved in the default styles file; when saving an existing style, it will be saved to the file it belongs to; the order of the styles files in the styles dropdown can be controlled to a certain degree by the order of --styles-file --- modules/cmd_args.py | 2 +- modules/shared.py | 3 +- modules/styles.py | 85 +++++++++++++++++++------------------ modules/ui_prompt_styles.py | 9 ++-- 4 files changed, 52 insertions(+), 47 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index e58059a1fea..f1251b6c850 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -88,7 +88,7 @@ parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
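To make the append semantics described in the commit message above concrete, here is a tiny self-contained sketch; the example paths are hypothetical, not taken from the patch:

    import argparse

    demo = argparse.ArgumentParser()
    demo.add_argument("--styles-file", type=str, action='append', default=[])
    args = demo.parse_args(["--styles-file", "styles.csv", "--styles-file", "custom/*.csv"])
    print(args.styles_file)  # ['styles.csv', 'custom/*.csv']

With no occurrences of the flag, the default [] survives, which shared.py (below) then swaps for [os.path.join(data_path, 'styles.csv')].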
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path]) parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") -parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) +parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[]) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) diff --git a/modules/shared.py b/modules/shared.py index 636619391fc..ccdca4e7040 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,3 +1,4 @@ +import os import sys import gradio as gr @@ -11,7 +12,7 @@ batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond parallel_processing_allowed = True -styles_filename = cmd_opts.styles_file +styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')] config_filename = cmd_opts.ui_settings_file hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} diff --git a/modules/styles.py b/modules/styles.py index 026c43001a4..9edcc7e4447 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -1,16 +1,15 @@ +from pathlib import Path import csv -import fnmatch import os -import os.path import typing import shutil class PromptStyle(typing.NamedTuple): name: str - prompt: str - negative_prompt: str - path: str = None + prompt: str | None + negative_prompt: str | None + path: str | None = None def merge_prompts(style_prompt: str, prompt: str) -> str: @@ -79,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt): class StyleDatabase: - def __init__(self, path: str): + def __init__(self, paths: list[str | Path]): self.no_style = PromptStyle("None", "", "", None) self.styles = {} - self.path = path - - folder, file = os.path.split(self.path) - filename, _, ext = file.partition('*') - self.default_path = os.path.join(folder, filename + ext) + self.paths = paths + self.all_styles_files: list[Path] = [] + + folder, file = os.path.split(self.paths[0]) + if '*' in file or '?' 
in file: + # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path + self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv'))) + self.paths.insert(0, self.default_path) + else: + self.default_path = Path(self.paths[0]) self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] @@ -99,33 +103,31 @@ def reload(self): """ self.styles.clear() - path, filename = os.path.split(self.path) - - if "*" in filename: - fileglob = filename.split("*")[0] + "*.csv" - filelist = [] - for file in os.listdir(path): - if fnmatch.fnmatch(file, fileglob): - filelist.append(file) - # Add a visible divider to the style list - half_len = round(len(file) / 2) - divider = f"{'-' * (20 - half_len)} {file.upper()}" - divider = f"{divider} {'-' * (40 - len(divider))}" - self.styles[divider] = PromptStyle( - f"{divider}", None, None, "do_not_save" - ) - # Add styles from this CSV file - self.load_from_csv(os.path.join(path, file)) - if len(filelist) == 0: - print(f"No styles found in {path} matching {fileglob}") - return - elif not os.path.exists(self.path): - print(f"Style database not found: {self.path}") - return - else: - self.load_from_csv(self.path) - - def load_from_csv(self, path: str): + # scans for all styles files + all_styles_files = [] + for pattern in self.paths: + folder, file = os.path.split(pattern) + if '*' in file or '?' in file: + found_files = Path(folder).glob(file) + [all_styles_files.append(file) for file in found_files] + else: + # if os.path.exists(pattern): + all_styles_files.append(Path(pattern)) + + # Remove any duplicate entries + seen = set() + self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))] + + for styles_file in self.all_styles_files: + if len(all_styles_files) > 1: + # add divider when more than styles file + # '---------------- STYLES ----------------' + divider = f' {styles_file.stem.upper()} '.center(40, '-') + self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save") + if styles_file.is_file(): + self.load_from_csv(styles_file) + + def load_from_csv(self, path: str | Path): with open(path, "r", encoding="utf-8-sig", newline="") as file: reader = csv.DictReader(file, skipinitialspace=True) for row in reader: @@ -137,7 +139,7 @@ def load_from_csv(self, path: str): negative_prompt = row.get("negative_prompt", "") # Add style to database self.styles[row["name"]] = PromptStyle( - row["name"], prompt, negative_prompt, path + row["name"], prompt, negative_prompt, str(path) ) def get_style_paths(self) -> set: @@ -145,11 +147,11 @@ def get_style_paths(self) -> set: # Update any styles without a path to the default path for style in list(self.styles.values()): if not style.path: - self.styles[style.name] = style._replace(path=self.default_path) + self.styles[style.name] = style._replace(path=str(self.default_path)) # Create a list of all distinct paths, including the default path style_paths = set() - style_paths.add(self.default_path) + style_paths.add(str(self.default_path)) for _, style in self.styles.items(): if style.path: style_paths.add(style.path) @@ -177,7 +179,6 @@ def apply_negative_styles_to_prompt(self, prompt, styles): def save_styles(self, path: str = None) -> None: # The path argument is deprecated, but kept for backwards compatibility - _ = path style_paths = self.get_style_paths() diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py index 0d74c23fa19..d67e3f17e10 100644 --- 
a/modules/ui_prompt_styles.py +++ b/modules/ui_prompt_styles.py @@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt): if not name: return gr.update(visible=False) - style = styles.PromptStyle(name, prompt, negative_prompt) + existing_style = shared.prompt_styles.styles.get(name) + path = existing_style.path if existing_style is not None else None + + style = styles.PromptStyle(name, prompt, negative_prompt, path) shared.prompt_styles.styles[style.name] = style - shared.prompt_styles.save_styles(shared.styles_filename) + shared.prompt_styles.save_styles() return gr.update(visible=True) @@ -34,7 +37,7 @@ def delete_style(name): return shared.prompt_styles.styles.pop(name, None) - shared.prompt_styles.save_styles(shared.styles_filename) + shared.prompt_styles.save_styles() return '', '', '' From 8459015017fde8833a4159b8d56a32f92ebf4186 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Sat, 20 Jan 2024 21:19:53 +0100 Subject: [PATCH 0624/1174] skip if headers haven't changed --- modules/ui_common.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index 92b00719238..2c67834c6c5 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -1,3 +1,4 @@ +import csv import dataclasses import json import html @@ -37,8 +38,6 @@ def plaintext_to_html(text, classname=None): def update_logfile(logfile_path, fields): - import csv - with open(logfile_path, "r", encoding="utf8", newline="") as file: reader = csv.reader(file) rows = list(reader) @@ -47,6 +46,10 @@ def update_logfile(logfile_path, fields): if not rows: return + # file is already synced, do nothing + if len(rows[0]) == len(fields): + return + rows[0] = fields # append new fields to each row as empty values @@ -60,7 +63,6 @@ def update_logfile(logfile_path, fields): def save_files(js_data, images, do_make_zip, index): - import csv filenames = [] fullfns = [] parsed_infotexts = [] From f190b85182a89f873fa99d897ac20310047131ea Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Sat, 20 Jan 2024 21:27:38 +0100 Subject: [PATCH 0625/1174] restore saving fields --- modules/ui_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index 2c67834c6c5..6eb740f16c5 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -132,7 +132,7 @@ def __init__(self, d=None): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt']]) + writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]]) # Make Zip if do_make_zip: From 4aa99f77abd439469f20e6e2941dd553031ed2d0 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Sat, 20 Jan 2024 22:04:53 +0100 Subject: [PATCH 0626/1174] add docstring --- modules/ui_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/ui_common.py b/modules/ui_common.py index 6eb740f16c5..29fe7d0e920 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -38,6 +38,7 @@ def plaintext_to_html(text, classname=None): def update_logfile(logfile_path, fields): + """Update a logfile from old format to new format to maintain CSV 
integrity.""" with open(logfile_path, "r", encoding="utf8", newline="") as file: reader = csv.reader(file) rows = list(reader) From e36827af3254f7bac9f8c78d6d56c709959b40b6 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 21 Jan 2024 07:20:52 +0900 Subject: [PATCH 0627/1174] improve get_crop_region --- modules/masking.py | 43 +++++++++---------------------------------- modules/processing.py | 2 +- 2 files changed, 10 insertions(+), 35 deletions(-) diff --git a/modules/masking.py b/modules/masking.py index be9f84c76f6..29a3945278e 100644 --- a/modules/masking.py +++ b/modules/masking.py @@ -3,40 +3,15 @@ def get_crop_region(mask, pad=0): """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. - For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" - - h, w = mask.shape - - crop_left = 0 - for i in range(w): - if not (mask[:, i] == 0).all(): - break - crop_left += 1 - - crop_right = 0 - for i in reversed(range(w)): - if not (mask[:, i] == 0).all(): - break - crop_right += 1 - - crop_top = 0 - for i in range(h): - if not (mask[i] == 0).all(): - break - crop_top += 1 - - crop_bottom = 0 - for i in reversed(range(h)): - if not (mask[i] == 0).all(): - break - crop_bottom += 1 - - return ( - int(max(crop_left-pad, 0)), - int(max(crop_top-pad, 0)), - int(min(w - crop_right + pad, w)), - int(min(h - crop_bottom + pad, h)) - ) + For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)""" + mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask) + box = mask_img.getbbox() + if box: + x1, y1, x2, y2 = box + else: # when no box is found + x1, y1 = mask_img.size + x2 = y2 = 0 + return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1]) def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height): diff --git a/modules/processing.py b/modules/processing.py index 6b63179512e..72d8093ba32 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1562,7 +1562,7 @@ def init(self, all_prompts, all_seeds, all_subseeds): if self.inpaint_full_res: self.mask_for_overlay = image_mask mask = image_mask.convert('L') - crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) + crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) x1, y1, x2, y2 = crop_region From 2974b9cee94dc474ffbc9e9617d14c9aaf9e1e63 Mon Sep 17 00:00:00 2001 From: Stefan Benten Date: Sun, 21 Jan 2024 14:05:47 +0100 Subject: [PATCH 0628/1174] modules/api/api.py: add api endpoint to refresh embeddings list --- modules/api/api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index b3d74e513a3..b6bb9d06ab1 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -230,6 +230,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], 
response_model=models.EmbeddingsResponse) + self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) @@ -643,6 +644,10 @@ def convert_embeddings(embeddings): "skipped": convert_embeddings(db.skipped_embeddings), } + def refresh_embeddings(self): + with self.queue_lock: + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + def refresh_checkpoints(self): with self.queue_lock: shared.refresh_checkpoints() From d7d3166a2749fe04f0ba1d8d0f88c8e8819a379d Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sun, 21 Jan 2024 11:27:24 -0500 Subject: [PATCH 0629/1174] Fix broken scrollbars --- html/extra-networks-tree.html | 4 +--- style.css | 20 +++++++------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html index 39649e86093..beec888cd9d 100644 --- a/html/extra-networks-tree.html +++ b/html/extra-networks-tree.html @@ -1,5 +1,3 @@
                  -
                  - {tree} -
                  + {tree}
                  \ No newline at end of file diff --git a/style.css b/style.css index f3fd1571bb1..ff1d907226a 100644 --- a/style.css +++ b/style.css @@ -1179,20 +1179,20 @@ body.resizing .resize-handle { resize: vertical; min-height: 52rem; flex-direction: column; + overflow: hidden; } .extra-network-pane .extra-network-pane-content { display: flex; flex: 1; - flex-direction: row; + overflow: hidden; } .extra-network-pane .extra-network-tree { flex: 1; - flex-direction: column; - display: flex; font-size: 1rem; border: 1px solid var(--block-border-color); + overflow: clip auto !important; } .extra-network-pane .extra-network-cards { @@ -1210,34 +1210,28 @@ body.resizing .resize-handle { overflow: hidden; } -.extra-network-pane .extra-network-tree .tree-list .tree-list-container { - flex: 1; - overflow: clip auto !important; - width: 100%; -} - .extra-network-pane .extra-network-cards::-webkit-scrollbar, -.extra-network-pane .tree-list-container::-webkit-scrollbar { +.extra-network-pane .extra-network-tree::-webkit-scrollbar { background-color: transparent; width: 16px; } .extra-network-pane .extra-network-cards::-webkit-scrollbar-track, -.extra-network-pane .tree-list-container::-webkit-scrollbar-track { +.extra-network-pane .extra-network-tree::-webkit-scrollbar-track { background-color: transparent; background-clip: content-box; } .extra-network-pane .extra-network-cards::-webkit-scrollbar-thumb, -.extra-network-pane .tree-list-container::-webkit-scrollbar-thumb { +.extra-network-pane .extra-network-tree::-webkit-scrollbar-thumb { background-color: var(--border-color-primary); border-radius: 16px; border: 4px solid var(--background-fill-primary); } .extra-network-pane .extra-network-cards::-webkit-scrollbar-button, -.extra-network-pane .tree-list-container::-webkit-scrollbar-button { +.extra-network-pane .extra-network-tree::-webkit-scrollbar-button { display: none; } From 26e1cd7ec47c8d234d2ea3f189b1147329c9059c Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sun, 21 Jan 2024 11:34:08 -0500 Subject: [PATCH 0630/1174] Remove unnecessary template and simplify tree list. --- html/extra-networks-tree.html | 3 --- modules/ui_extra_networks.py | 11 +---------- 2 files changed, 1 insertion(+), 13 deletions(-) delete mode 100644 html/extra-networks-tree.html diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html deleted file mode 100644 index beec888cd9d..00000000000 --- a/html/extra-networks-tree.html +++ /dev/null @@ -1,3 +0,0 @@ -
                  - {tree} -
                  \ No newline at end of file diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 80160b847b6..157b3a6d4a9 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -164,7 +164,6 @@ def __init__(self, title): self.lister = util.MassFileLister() # HTML Templates self.pane_tpl = shared.html("extra-networks-pane.html") - self.tree_tpl = shared.html("extra-networks-tree.html") self.card_tpl = shared.html("extra-networks-card.html") self.btn_tree_tpl = shared.html("extra-networks-tree-button.html") self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") @@ -420,8 +419,6 @@ def create_tree_file_item_html(self, tabname: str, file_path: str, item: dict) - def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. - The generated HTML uses `extra-networks-tree.html` as a template. - Args: tabname: The name of the active tab. @@ -473,13 +470,7 @@ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional if item_html is not None: res += item_html - return self.tree_tpl.format( - **{ - "tabname": tabname, - "extra_networks_tabname": self.extra_networks_tabname, - "tree": f"
                    {res}
                  " - } - ) + return f"
                    {res}
                  " def create_card_view_html(self, tabname: str) -> str: """Generates HTML for the network Card View section for a tab. From fd383140cf405100f3c619f106472273a7545beb Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 22 Jan 2024 02:52:34 -0800 Subject: [PATCH 0631/1174] fix: wrong devices for eye and constraint --- extensions-builtin/Lora/network_oft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 342fcd0dcaf..d1c46a4b22e 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -57,12 +57,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): def calc_updown(self, orig_weight): oft_blocks = self.oft_blocks.to(orig_weight.device) - eye = torch.eye(self.block_size, device=self.oft_blocks.device) + eye = torch.eye(self.block_size, device=oft_blocks.device) if self.is_kohya: block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device)) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) From f4e931f18fa4f94aece1f4dabd4dd0d635ecec13 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 22 Jan 2024 23:20:30 +0300 Subject: [PATCH 0632/1174] put extra networks controls row into the tabs UI element for #14588 --- html/extra-networks-pane.html | 2 +- javascript/extraNetworks.js | 24 +++++++++++++++-- modules/ui.py | 4 +-- modules/ui_extra_networks.py | 2 +- style.css | 51 +++++++++++++++++++---------------- 5 files changed, 54 insertions(+), 29 deletions(-) diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 9f5b3ecee62..0c763f71058 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,5 +1,5 @@
                  -
                  +