diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py
index 46d8c2484a..1388db2ee8 100644
--- a/torchao/_models/llama/model.py
+++ b/torchao/_models/llama/model.py
@@ -478,12 +478,24 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+        self.w13 = nn.Linear(config.dim, 2 * config.intermediate_size, bias=False)
+        self._register_load_state_dict_pre_hook(self._merge_w1_w3)
+
+    @staticmethod
+    def _merge_w1_w3(state_dict, prefix, *args):
+        w1_key = prefix + "w1.weight"
+        w3_key = prefix + "w3.weight"
+        if w1_key in state_dict and w3_key in state_dict:
+            w1 = state_dict.pop(w1_key)
+            w3 = state_dict.pop(w3_key)
+            w13_key = prefix + "w13.weight"
+            state_dict[w13_key] = torch.cat([w1, w3])
+

     def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+        x1, x3 = self.w13(x).chunk(2, dim=-1)
+        return self.w2(F.silu(x1) * x3)


 class RMSNorm(nn.Module):
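For context (not part of the diff itself): the change fuses the two input projections of the SwiGLU MLP, `w1` (gate) and `w3` (up), into one wider `nn.Linear` so the block runs a single GEMM instead of two, and the `_register_load_state_dict_pre_hook` keeps old checkpoints loadable by concatenating `w1.weight` and `w3.weight` into `w13.weight` on the fly. A minimal standalone sketch of the equivalence, using made-up sizes rather than anything from `ModelArgs`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim, intermediate_size = 64, 176  # illustrative sizes, not from the PR

# Reference: separate gate (w1) and up (w3) projections, as before the change.
w1 = nn.Linear(dim, intermediate_size, bias=False)
w3 = nn.Linear(dim, intermediate_size, bias=False)
w2 = nn.Linear(intermediate_size, dim, bias=False)

# Fused: one linear whose weight is the row-wise concatenation [w1; w3],
# mirroring the torch.cat([w1, w3]) done in the load_state_dict pre-hook.
w13 = nn.Linear(dim, 2 * intermediate_size, bias=False)
with torch.no_grad():
    w13.weight.copy_(torch.cat([w1.weight, w3.weight]))

x = torch.randn(2, 8, dim)
ref = w2(F.silu(w1(x)) * w3(x))
x1, x3 = w13(x).chunk(2, dim=-1)  # first half <-> w1, second half <-> w3
fused = w2(F.silu(x1) * x3)
torch.testing.assert_close(ref, fused)  # equal up to floating-point round-off
print("fused w13 matches separate w1/w3")
```

Note that the `chunk(2, dim=-1)` order in `forward` must match the `torch.cat` order in the hook (w1's rows first, then w3's); the diff is consistent on this, and the sketch above follows the same convention.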