Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions scripts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sd_meh.merge import NUM_TOTAL_BLOCKS, merge_methods, merge_models

MEMORY_DESTINATION = "memory"
persistent_cache: Optional[Dict] = None


def on_app_started(_gui: Optional[gr.Blocks], api: fastapi.FastAPI):
Expand Down Expand Up @@ -51,6 +52,7 @@ async def merge_models_api(
title="Number of threads",
description="Number of keys to merge simultaneously. Only useful with device='cpu'",
),
cache: bool = fastapi.Body(False, title="Cache intermediate merge values"),
):
validate_merge_method(merge_method)
alpha, beta, input_models, weights, bases = normalize_merge_args(
Expand All @@ -63,6 +65,10 @@ async def merge_models_api(
model_c,
)

global persistent_cache
if cache and persistent_cache is None:
persistent_cache = {}

model_a_info = get_checkpoint_info(Path(model_a))
load_in_memory = destination == MEMORY_DESTINATION
if not load_in_memory:
Expand All @@ -88,9 +94,9 @@ async def merge_models_api(
work_device=work_device,
prune=prune,
threads=threads,
cache=persistent_cache,
)
if not isinstance(merged, dict):
merged = merged.to_dict()
merged = merged.to_dict()

if load_in_memory:
sd_models.load_model(model_a_info, merged)
Expand Down
1 change: 1 addition & 0 deletions sd_webui_bayesian_merger/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def merge(
"unload_before": True,
"re_basin": self.cfg.rebasin,
"re_basin_iterations": self.cfg.rebasin_iterations,
"cache": self.cfg.cache_merge,
}

print("Merging models")
Expand Down