From 3654aa3e685e2bdc820c45d1d164004034fc9958 Mon Sep 17 00:00:00 2001 From: ikhyunAn Date: Tue, 4 Nov 2025 13:36:54 -0500 Subject: [PATCH 1/2] feat: support cross-model family merge --- mergekit/architecture/__init__.py | 9 +++++++++ mergekit/io/tensor_writer.py | 1 + mergekit/merge_methods/slerp.py | 12 +++++++++--- mergekit/tokenizer/embed.py | 5 +++-- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/mergekit/architecture/__init__.py b/mergekit/architecture/__init__.py index f97fd9db..b39260eb 100644 --- a/mergekit/architecture/__init__.py +++ b/mergekit/architecture/__init__.py @@ -86,6 +86,15 @@ def get_architecture_info( raise RuntimeError( "Must specify --allow-crimes to attempt to mix different architectures" ) + # Prefer using the base_model's architecture when mixing different families + # to ensure the output layout matches the base. + if config.base_model is not None: + try: + idx = models.index(config.base_model) + return model_arch_info[idx] + except ValueError: + # base_model not in referenced models; fall back to first + pass return model_arch_info[0] # try to infer from all models diff --git a/mergekit/io/tensor_writer.py b/mergekit/io/tensor_writer.py index f75ca4c5..bf537b69 100644 --- a/mergekit/io/tensor_writer.py +++ b/mergekit/io/tensor_writer.py @@ -39,6 +39,7 @@ def __init__( use_async: bool = False, max_write_threads: int = 1, ) -> None: + out_path = os.path.abspath(os.path.expanduser(out_path)) os.makedirs(out_path, exist_ok=True) self.out_path = out_path diff --git a/mergekit/merge_methods/slerp.py b/mergekit/merge_methods/slerp.py index b48d0e4d..3793e09e 100644 --- a/mergekit/merge_methods/slerp.py +++ b/mergekit/merge_methods/slerp.py @@ -21,7 +21,7 @@ class SlerpTask(Task[torch.Tensor]): gather_tensors: MergeTensorInput base_model: ModelReference - t: float + t: Optional[float] weight_info: WeightInfo def uses_accelerator(self) -> bool: @@ -38,6 +38,12 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor: elif self.base_model not in tensors: raise RuntimeError("Base model not in input tensors") + # If no interpolation parameter was provided for this tensor, do not attempt to merge; + # simply return the base model's weight unchanged. This avoids shape/broadcast errors + # when the secondary model has incompatible tensor shapes. + if self.t is None: + return tensors[self.base_model] + [a, b] = list(tensors.items()) if a[0] != self.base_model: [a, b] = [b, a] @@ -72,7 +78,7 @@ def reference_url(self): return "https://en.wikipedia.org/wiki/Slerp" def parameters(self) -> List[ConfigParameterDef]: - return [ConfigParameterDef(name="t", required=True)] + return [ConfigParameterDef(name="t", required=False, default_value=None)] def make_task( self, @@ -92,7 +98,7 @@ def make_task( def lerp( - t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor] + t: Optional[float], v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor] ) -> Union[np.ndarray, torch.Tensor]: return (1 - t) * v0 + t * v1 diff --git a/mergekit/tokenizer/embed.py b/mergekit/tokenizer/embed.py index aa0dcb89..e6379fb6 100644 --- a/mergekit/tokenizer/embed.py +++ b/mergekit/tokenizer/embed.py @@ -130,7 +130,7 @@ def assign_embedding_sources( continue if num_present == 0: - token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding()) + token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding(kind="zero")) logging.warning(f"Token {repr(token)} not found in any model") continue @@ -152,7 +152,8 @@ def compute_default_embedding( cfg: TokenEmbeddingConfig, ) -> torch.Tensor: if isinstance(cfg.source, ZeroEmbedding): - pass + any_tensor = next(iter(tensors.values())) + embed = torch.zeros(any_tensor.shape[1], dtype=any_tensor.dtype, device=any_tensor.device) elif isinstance(cfg.source, ModelTokenEmbedding): model = cfg.source.model assert ( From e85a454f39ec3669d9ed37d9455dfdb61bbb3cf4 Mon Sep 17 00:00:00 2001 From: ikhyunAn Date: Tue, 4 Nov 2025 14:20:00 -0500 Subject: [PATCH 2/2] fix: pre-commit hook --- mergekit/merge_methods/slerp.py | 4 +++- mergekit/tokenizer/embed.py | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mergekit/merge_methods/slerp.py b/mergekit/merge_methods/slerp.py index 3793e09e..db36a748 100644 --- a/mergekit/merge_methods/slerp.py +++ b/mergekit/merge_methods/slerp.py @@ -98,7 +98,9 @@ def make_task( def lerp( - t: Optional[float], v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor] + t: Optional[float], + v0: Union[np.ndarray, torch.Tensor], + v1: Union[np.ndarray, torch.Tensor], ) -> Union[np.ndarray, torch.Tensor]: return (1 - t) * v0 + t * v1 diff --git a/mergekit/tokenizer/embed.py b/mergekit/tokenizer/embed.py index e6379fb6..5755fdc6 100644 --- a/mergekit/tokenizer/embed.py +++ b/mergekit/tokenizer/embed.py @@ -130,7 +130,9 @@ def assign_embedding_sources( continue if num_present == 0: - token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding(kind="zero")) + token_configs[token] = TokenEmbeddingConfig( + source=ZeroEmbedding(kind="zero") + ) logging.warning(f"Token {repr(token)} not found in any model") continue @@ -153,7 +155,9 @@ def compute_default_embedding( ) -> torch.Tensor: if isinstance(cfg.source, ZeroEmbedding): any_tensor = next(iter(tensors.values())) - embed = torch.zeros(any_tensor.shape[1], dtype=any_tensor.dtype, device=any_tensor.device) + embed = torch.zeros( + any_tensor.shape[1], dtype=any_tensor.dtype, device=any_tensor.device + ) elif isinstance(cfg.source, ModelTokenEmbedding): model = cfg.source.model assert (