From 2cecc4973c19517d6a979987cd395417489ebf93 Mon Sep 17 00:00:00 2001
From: ytgong <1145028706@qq.com>
Date: Thu, 23 May 2024 15:07:15 +0800
Subject: [PATCH 1/2] fix(model.py): align_precision_in_norm_layer

---
 collie/models/base.py            | 5 ++++-
 collie/models/internlm/model.py  | 3 +++
 collie/models/internlm2/model.py | 3 +++
 collie/models/llama/model.py     | 3 +++
 collie/models/moss/model.py      | 3 +++
 collie/models/moss_moon/model.py | 8 ++++++++
 6 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/collie/models/base.py b/collie/models/base.py
index 247c8a7..3507a77 100644
--- a/collie/models/base.py
+++ b/collie/models/base.py
@@ -337,7 +337,10 @@ def from_pretrained(cls, model_path_or_name: str, config: Union[CollieConfig, st
             if name in state_dict:
                 assert param.data.shape == state_dict[name].data.shape, f"The shape of the parameter corresponding to the `{name}` does not match: {param.data.shape} vs {state_dict[name].data.shape}"
                 param.data = value.to(param.device)
-
+
+        for name, layer in model.named_modules():
+            if hasattr(layer, 'set_norm_precision_to_float32'):
+                layer.set_norm_precision_to_float32()
 
         if config.peft_config.peft_type is not None:
             model = get_peft_model(model, config.peft_config)
diff --git a/collie/models/internlm/model.py b/collie/models/internlm/model.py
index b674317..192a4c6 100644
--- a/collie/models/internlm/model.py
+++ b/collie/models/internlm/model.py
@@ -80,6 +80,9 @@ def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
             )
         self.eps = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
diff --git a/collie/models/internlm2/model.py b/collie/models/internlm2/model.py
index 53b1d8d..af226b2 100644
--- a/collie/models/internlm2/model.py
+++ b/collie/models/internlm2/model.py
@@ -122,6 +122,9 @@ def __init__(self, hidden_size, eps=1e-6):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
diff --git a/collie/models/llama/model.py b/collie/models/llama/model.py
index 067d48c..a7d91e7 100644
--- a/collie/models/llama/model.py
+++ b/collie/models/llama/model.py
@@ -80,6 +80,9 @@ def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
             )
         self.eps = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
diff --git a/collie/models/moss/model.py b/collie/models/moss/model.py
index 5a87bf7..f28aeda 100644
--- a/collie/models/moss/model.py
+++ b/collie/models/moss/model.py
@@ -80,6 +80,9 @@ def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
             )
         self.eps = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
diff --git a/collie/models/moss_moon/model.py b/collie/models/moss_moon/model.py
index f8e3b15..63afce0 100644
--- a/collie/models/moss_moon/model.py
+++ b/collie/models/moss_moon/model.py
@@ -271,6 +271,10 @@ def __init__(self, config, layer_idx):
         self.use_cache = False
         self.hidden_states = None
 
+    def set_norm_precision_to_float32(self):
+        self.ln_1.weight.data = self.ln_1.weight.data.to(torch.float32)
+        self.ln_1.bias.data = self.ln_1.bias.data.to(torch.float32)
+
     def _forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
@@ -394,6 +398,10 @@ def __init__(self, config):
         self.h = nn.ModuleList([MossBlock(config, i) for i in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
+    def set_norm_precision_to_float32(self):
+        self.ln_f.weight.data = self.ln_f.weight.data.to(torch.float32)
+        self.ln_f.bias.data = self.ln_f.bias.data.to(torch.float32)
+
     def forward(
         self,
         input_ids,

From 2b95d69057b002a68f8a076fab3d83f1affd856c Mon Sep 17 00:00:00 2001
From: ytgong <1145028706@qq.com>
Date: Tue, 28 May 2024 10:42:21 +0800
Subject: [PATCH 2/2] fix(model.py): align_precision_in_norm_layer

---
 collie/controller/trainer.py     | 4 ++++
 collie/models/base.py            | 4 ----
 collie/models/internlm/model.py  | 3 ++-
 collie/models/internlm2/model.py | 2 +-
 collie/models/llama/model.py     | 3 ++-
 collie/models/moss/model.py      | 3 ++-
 collie/models/moss_moon/model.py | 6 ++++--
 7 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/collie/controller/trainer.py b/collie/controller/trainer.py
index 9c082ce..8b914c0 100644
--- a/collie/controller/trainer.py
+++ b/collie/controller/trainer.py
@@ -237,6 +237,10 @@ def __init__(
         self.on_after_trainer_initialized()
         torch.cuda.empty_cache()
 
+        for name, layer in self.model.named_modules():
+            if hasattr(layer, 'set_norm_precision_to_float32'):
+                layer.set_norm_precision_to_float32()
+
     def init_state_dict(self):
         """Initialize the optimizer's own state dict."""
         self.epoch_idx = 0
diff --git a/collie/models/base.py b/collie/models/base.py
index 3507a77..3a8796d 100644
--- a/collie/models/base.py
+++ b/collie/models/base.py
@@ -338,10 +338,6 @@ def from_pretrained(cls, model_path_or_name: str, config: Union[CollieConfig, st
                 assert param.data.shape == state_dict[name].data.shape, f"The shape of the parameter corresponding to the `{name}` does not match: {param.data.shape} vs {state_dict[name].data.shape}"
                 param.data = value.to(param.device)
 
-        for name, layer in model.named_modules():
-            if hasattr(layer, 'set_norm_precision_to_float32'):
-                layer.set_norm_precision_to_float32()
-
         if config.peft_config.peft_type is not None:
             model = get_peft_model(model, config.peft_config)
             model.print_trainable_parameters()
diff --git a/collie/models/internlm/model.py b/collie/models/internlm/model.py
index 192a4c6..e7cb347 100644
--- a/collie/models/internlm/model.py
+++ b/collie/models/internlm/model.py
@@ -84,11 +84,12 @@ def set_norm_precision_to_float32(self):
         self.weight.data = self.weight.data.to(torch.float32)
 
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        return (hidden_states * self.weight).to(input_dtype)
 
 
 class InternLMLayer(nn.Module):
diff --git a/collie/models/internlm2/model.py b/collie/models/internlm2/model.py
index af226b2..73535cc 100644
--- a/collie/models/internlm2/model.py
+++ b/collie/models/internlm2/model.py
@@ -130,7 +130,7 @@ def forward(self, hidden_states):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
diff --git a/collie/models/llama/model.py b/collie/models/llama/model.py
index a7d91e7..97a4efd 100644
--- a/collie/models/llama/model.py
+++ b/collie/models/llama/model.py
@@ -84,11 +84,12 @@ def set_norm_precision_to_float32(self):
         self.weight.data = self.weight.data.to(torch.float32)
 
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        return (hidden_states * self.weight).to(input_dtype)
 
     def post_init(self):
         """
diff --git a/collie/models/moss/model.py b/collie/models/moss/model.py
index f28aeda..da18d02 100644
--- a/collie/models/moss/model.py
+++ b/collie/models/moss/model.py
@@ -84,11 +84,12 @@ def set_norm_precision_to_float32(self):
         self.weight.data = self.weight.data.to(torch.float32)
 
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        return (hidden_states * self.weight).to(input_dtype)
 
 
 class MossBlock(nn.Module):
diff --git a/collie/models/moss_moon/model.py b/collie/models/moss_moon/model.py
index 63afce0..763b941 100644
--- a/collie/models/moss_moon/model.py
+++ b/collie/models/moss_moon/model.py
@@ -288,7 +288,8 @@ def _forward(
         Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
     ]:
         residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
+        input_dtype = hidden_states.dtype
+        hidden_states = self.ln_1(hidden_states.to(torch.float32)).to(input_dtype)
         attn_outputs = self.attn(
             hidden_states=hidden_states,
             layer_past=layer_past,
@@ -433,7 +434,8 @@ def forward(
             all_hidden_states += (input_dict["hidden_states"],)
             input_dict.update(l(input_dict))
 
-        hidden_states = self.ln_f(input_dict["hidden_states"])
+        input_dtype = input_dict["hidden_states"].dtype
+        hidden_states = self.ln_f(input_dict["hidden_states"].to(torch.float32)).to(input_dtype)
         all_hidden_states += (hidden_states,)
 
         past_key_values = None
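
Taken together, the two patches keep the norm-layer weights in float32 via the new
set_norm_precision_to_float32 hook (invoked from Trainer.__init__ after patch 2 rather
than from from_pretrained), compute the normalization statistics in float32, and cast
the result back to the activation dtype so the surrounding float16/bfloat16 layers see
the dtype they expect. The standalone sketch below illustrates that dtype behaviour in
plain PyTorch; the RMSNormSketch class and the usage example are illustrative
assumptions, not CoLLiE APIs.

import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Minimal RMSNorm mirroring the dtype handling introduced by the patches."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def set_norm_precision_to_float32(self):
        # As in the patches: promote only the norm weight to float32.
        self.weight.data = self.weight.data.to(torch.float32)

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the statistics in float32 regardless of the activation dtype.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # Scale by the float32 weight, then hand back the caller's dtype.
        return (self.weight * hidden_states).to(input_dtype)


if __name__ == "__main__":
    norm = RMSNormSketch(dim=8).to(torch.bfloat16)  # model-wide cast to bf16
    norm.set_norm_precision_to_float32()            # norm weight restored to float32
    x = torch.randn(2, 8, dtype=torch.bfloat16)
    y = norm(x)
    print(norm.weight.dtype, y.dtype)               # torch.float32 torch.bfloat16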