diff --git a/collie/controller/trainer.py b/collie/controller/trainer.py
index 9c082ce..8b914c0 100644
--- a/collie/controller/trainer.py
+++ b/collie/controller/trainer.py
@@ -237,6 +237,10 @@ def __init__(
         self.on_after_trainer_initialized()
         torch.cuda.empty_cache()
 
+        for name, layer in self.model.named_modules():
+            if hasattr(layer, 'set_norm_precision_to_float32'):
+                layer.set_norm_precision_to_float32()
+
     def init_state_dict(self):
        """Initialize the optimizer's own state dict."""
         self.epoch_idx = 0
diff --git a/collie/models/base.py b/collie/models/base.py
index 247c8a7..3a8796d 100644
--- a/collie/models/base.py
+++ b/collie/models/base.py
@@ -337,7 +337,6 @@ def from_pretrained(cls, model_path_or_name: str, config: Union[CollieConfig, st
                 if name in state_dict:
                     assert param.data.shape == state_dict[name].data.shape, f"The shape of the parameter corresponding to the `{name}` does not match: {param.data.shape} vs {state_dict[name].data.shape}"
                     param.data = value.to(param.device)
-
         if config.peft_config.peft_type is not None:
             model = get_peft_model(model, config.peft_config)
 
diff --git a/collie/models/internlm/model.py b/collie/models/internlm/model.py
index b674317..e7cb347 100644
--- a/collie/models/internlm/model.py
+++ b/collie/models/internlm/model.py
@@ -80,12 +80,16 @@ def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
             )
         self.eps = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        return (hidden_states * self.weight).to(input_dtype)
 
 
 class InternLMLayer(nn.Module):
diff --git a/collie/models/internlm2/model.py b/collie/models/internlm2/model.py
index 53b1d8d..73535cc 100644
--- a/collie/models/internlm2/model.py
+++ b/collie/models/internlm2/model.py
@@ -122,12 +122,15 @@ def __init__(self, hidden_size, eps=1e-6):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
diff --git a/collie/models/llama/model.py b/collie/models/llama/model.py
index 067d48c..97a4efd 100644
--- a/collie/models/llama/model.py
+++ b/collie/models/llama/model.py
@@ -80,12 +80,16 @@ def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
             )
         self.eps = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        return (hidden_states * self.weight).to(input_dtype)
 
     def post_init(self):
         """
diff --git a/collie/models/moss/model.py b/collie/models/moss/model.py
index 5a87bf7..da18d02 100644
--- a/collie/models/moss/model.py
+++ b/collie/models/moss/model.py
@@ -80,12 +80,16 @@ def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
             )
         self.eps = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        return (hidden_states * self.weight).to(input_dtype)
 
 
 class MossBlock(nn.Module):
diff --git a/collie/models/moss_moon/model.py b/collie/models/moss_moon/model.py
index f8e3b15..763b941 100644
--- a/collie/models/moss_moon/model.py
+++ b/collie/models/moss_moon/model.py
@@ -271,6 +271,10 @@ def __init__(self, config, layer_idx):
         self.use_cache = False
         self.hidden_states = None
 
+    def set_norm_precision_to_float32(self):
+        self.ln_1.weight.data = self.ln_1.weight.data.to(torch.float32)
+        self.ln_1.bias.data = self.ln_1.bias.data.to(torch.float32)
+
     def _forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
@@ -284,7 +288,8 @@
         Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
     ]:
         residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
+        input_dtype = hidden_states.dtype
+        hidden_states = self.ln_1(hidden_states.to(torch.float32)).to(input_dtype)
         attn_outputs = self.attn(
             hidden_states=hidden_states,
             layer_past=layer_past,
@@ -394,6 +399,10 @@ def __init__(self, config):
         self.h = nn.ModuleList([MossBlock(config, i) for i in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
+    def set_norm_precision_to_float32(self):
+        self.ln_f.weight.data = self.ln_f.weight.data.to(torch.float32)
+        self.ln_f.bias.data = self.ln_f.bias.data.to(torch.float32)
+
     def forward(
         self,
         input_ids,
@@ -425,7 +434,8 @@ def forward(
                 all_hidden_states += (input_dict["hidden_states"],)
             input_dict.update(l(input_dict))
 
-        hidden_states = self.ln_f(input_dict["hidden_states"])
+        input_dtype = input_dict["hidden_states"].dtype
+        hidden_states = self.ln_f(input_dict["hidden_states"].to(torch.float32)).to(input_dtype)
         all_hidden_states += (hidden_states,)
 
         past_key_values = None
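For reference, the following is a minimal standalone sketch of the pattern this diff applies across the norm layers; it is not CoLLiE code, and the class body and variable names are illustrative. The norm's learnable parameters are upcast to float32 once via the trainer hook, the forward pass accumulates in float32, and the result is cast back to the caller's dtype:

# Hedged sketch of the set_norm_precision_to_float32 pattern in the diff above.
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def set_norm_precision_to_float32(self):
        # Keep the learnable scale in float32 regardless of the model dtype.
        self.weight.data = self.weight.data.to(torch.float32)

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Accumulate the variance in float32 to avoid fp16/bf16 overflow/underflow.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # Cast back so downstream fp16/bf16 kernels see the original dtype.
        return (self.weight * hidden_states).to(input_dtype)

model = nn.Sequential(RMSNorm(8)).to(torch.float16)
# Mirrors the loop added to Trainer.__init__ above.
for _, layer in model.named_modules():
    if hasattr(layer, 'set_norm_precision_to_float32'):
        layer.set_norm_precision_to_float32()

out = model(torch.randn(2, 8, dtype=torch.float16))
assert out.dtype == torch.float16               # activations keep their dtype
assert model[0].weight.dtype == torch.float32   # norm weight stays in float32

Note that before this change the InternLM/LLaMA/Moss RMSNorm returned the product in the weight's dtype, so upcasting the weight to float32 would have leaked float32 activations into the rest of the fp16/bf16 network; capturing input_dtype and casting the result back keeps the activation dtype stable.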