9 changes: 9 additions & 0 deletions mergekit/architecture/__init__.py

@@ -86,6 +86,15 @@ def get_architecture_info(
         raise RuntimeError(
             "Must specify --allow-crimes to attempt to mix different architectures"
         )
+        # Prefer using the base_model's architecture when mixing different families
+        # to ensure the output layout matches the base.
+        if config.base_model is not None:
+            try:
+                idx = models.index(config.base_model)
+                return model_arch_info[idx]
+            except ValueError:
+                # base_model not in referenced models; fall back to first
+                pass
         return model_arch_info[0]

     # try to infer from all models
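The effect of this hunk can be read as a small standalone function. The sketch below is illustrative only: `pick_architecture` and its parameter list are hypothetical stand-ins for the internals of `get_architecture_info`, where `models` and `model_arch_info` are parallel lists.

```python
from typing import Any, List, Optional

def pick_architecture(
    models: List[Any], model_arch_info: List[Any], base_model: Optional[Any]
) -> Any:
    # Prefer the architecture of the declared base model so the merged
    # output's layout matches the base when families are mixed.
    if base_model is not None:
        try:
            return model_arch_info[models.index(base_model)]
        except ValueError:
            # base_model is not among the referenced models; fall back.
            pass
    # Previous behavior: arbitrarily take the first model's architecture.
    return model_arch_info[0]
```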
1 change: 1 addition & 0 deletions mergekit/io/tensor_writer.py

@@ -39,6 +39,7 @@ def __init__(
         use_async: bool = False,
         max_write_threads: int = 1,
     ) -> None:
+        out_path = os.path.abspath(os.path.expanduser(out_path))
         os.makedirs(out_path, exist_ok=True)

         self.out_path = out_path
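The single added line normalizes the output path before any directory is created. A quick illustration of what the two standard-library calls do (the path value here is just an example):

```python
import os

out_path = "~/merges/my-model"
# expanduser turns the leading "~" into the user's home directory;
# abspath resolves the result against the current working directory,
# so later relative writes cannot point at the wrong location.
out_path = os.path.abspath(os.path.expanduser(out_path))
os.makedirs(out_path, exist_ok=True)
print(out_path)  # e.g. /home/alice/merges/my-model
```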
14 changes: 11 additions & 3 deletions mergekit/merge_methods/slerp.py

@@ -21,7 +21,7 @@
 class SlerpTask(Task[torch.Tensor]):
     gather_tensors: MergeTensorInput
     base_model: ModelReference
-    t: float
+    t: Optional[float]
     weight_info: WeightInfo

     def uses_accelerator(self) -> bool:
@@ -38,6 +38,12 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
         elif self.base_model not in tensors:
             raise RuntimeError("Base model not in input tensors")

+        # If no interpolation parameter was provided for this tensor, do not attempt to merge;
+        # simply return the base model's weight unchanged. This avoids shape/broadcast errors
+        # when the secondary model has incompatible tensor shapes.
+        if self.t is None:
+            return tensors[self.base_model]
+
         [a, b] = list(tensors.items())
         if a[0] != self.base_model:
             [a, b] = [b, a]
@@ -72,7 +78,7 @@ def reference_url(self):
         return "https://en.wikipedia.org/wiki/Slerp"

     def parameters(self) -> List[ConfigParameterDef]:
-        return [ConfigParameterDef(name="t", required=True)]
+        return [ConfigParameterDef(name="t", required=False, default_value=None)]

     def make_task(
         self,
@@ -92,7 +98,9 @@ def make_task(

 def lerp(
-    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
+    t: Optional[float],
+    v0: Union[np.ndarray, torch.Tensor],
+    v1: Union[np.ndarray, torch.Tensor],
 ) -> Union[np.ndarray, torch.Tensor]:
     return (1 - t) * v0 + t * v1
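Taken together, these hunks make `t` optional end to end: the parameter definition no longer requires it, the task's type admits `None`, and `execute` short-circuits before any interpolation happens. A minimal sketch of the guard, using the linear blend from `lerp` as a stand-in for the full slerp path (the names below are illustrative, not the module's API):

```python
from typing import Optional

import torch

def blend(t: Optional[float], base: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    if t is None:
        # No interpolation parameter for this tensor: return the base
        # weight untouched, sidestepping any shape/broadcast error.
        return base
    return (1 - t) * base + t * other  # the lerp() formula from the diff

base = torch.ones(4, 4)
other = torch.zeros(4, 3)                # deliberately incompatible shape
assert blend(None, base, other) is base  # guard fires before any broadcast
```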
9 changes: 7 additions & 2 deletions mergekit/tokenizer/embed.py

@@ -130,7 +130,9 @@ def assign_embedding_sources(
             continue

         if num_present == 0:
-            token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding())
+            token_configs[token] = TokenEmbeddingConfig(
+                source=ZeroEmbedding(kind="zero")
+            )
             logging.warning(f"Token {repr(token)} not found in any model")
             continue

@@ -152,7 +154,10 @@ def compute_default_embedding(
     cfg: TokenEmbeddingConfig,
 ) -> torch.Tensor:
     if isinstance(cfg.source, ZeroEmbedding):
-        pass
+        any_tensor = next(iter(tensors.values()))
+        embed = torch.zeros(
+            any_tensor.shape[1], dtype=any_tensor.dtype, device=any_tensor.device
+        )
     elif isinstance(cfg.source, ModelTokenEmbedding):
         model = cfg.source.model
         assert (
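The second hunk replaces a bare `pass` with an explicit all-zero vector shaped like one row of any donor embedding matrix, matching its dtype and device. A small self-contained check of that construction (the tensor here is a made-up example, not real model weights):

```python
import torch

# tensors maps each model to its [vocab_size, hidden_dim] embedding matrix.
tensors = {"model_a": torch.randn(8, 16, dtype=torch.float16)}

any_tensor = next(iter(tensors.values()))
embed = torch.zeros(
    any_tensor.shape[1], dtype=any_tensor.dtype, device=any_tensor.device
)
assert embed.shape == (16,) and embed.dtype == torch.float16
```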