From 2a4336fbb20438abaaa66196dace7e3e69467dc7 Mon Sep 17 00:00:00 2001 From: ljleb Date: Sat, 27 Jan 2024 13:32:19 -0500 Subject: [PATCH] cache stuff --- scripts/api.py | 10 ++++++++-- sd_webui_bayesian_merger/merger.py | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/scripts/api.py b/scripts/api.py index 0f2e371..3edaaeb 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -11,6 +11,7 @@ from sd_meh.merge import NUM_TOTAL_BLOCKS, merge_methods, merge_models MEMORY_DESTINATION = "memory" +persistent_cache: Optional[Dict] = None def on_app_started(_gui: Optional[gr.Blocks], api: fastapi.FastAPI): @@ -51,6 +52,7 @@ async def merge_models_api( title="Number of threads", description="Number of keys to merge simultaneously. Only useful with device='cpu'", ), + cache: bool = fastapi.Body(False, title="Cache intermediate merge values"), ): validate_merge_method(merge_method) alpha, beta, input_models, weights, bases = normalize_merge_args( @@ -63,6 +65,10 @@ async def merge_models_api( model_c, ) + global persistent_cache + if cache and persistent_cache is None: + persistent_cache = {} + model_a_info = get_checkpoint_info(Path(model_a)) load_in_memory = destination == MEMORY_DESTINATION if not load_in_memory: @@ -88,9 +94,9 @@ async def merge_models_api( work_device=work_device, prune=prune, threads=threads, + cache=persistent_cache, ) - if not isinstance(merged, dict): - merged = merged.to_dict() + merged = merged.to_dict() if load_in_memory: sd_models.load_model(model_a_info, merged) diff --git a/sd_webui_bayesian_merger/merger.py b/sd_webui_bayesian_merger/merger.py index 37d963d..78f8130 100644 --- a/sd_webui_bayesian_merger/merger.py +++ b/sd_webui_bayesian_merger/merger.py @@ -108,6 +108,7 @@ def merge( "unload_before": True, "re_basin": self.cfg.rebasin, "re_basin_iterations": self.cfg.rebasin_iterations, + "cache": self.cfg.cache_merge, } print("Merging models")