4 changes: 4 additions & 0 deletions collie/controller/trainer.py
@@ -237,6 +237,10 @@ def __init__(
         self.on_after_trainer_initialized()
         torch.cuda.empty_cache()
 
+        for name, layer in self.model.named_modules():
+            if hasattr(layer, 'set_norm_precision_to_float32'):
+                layer.set_norm_precision_to_float32()
+
     def init_state_dict(self):
         """Initialize the optimizer's own state dict."""
         self.epoch_idx = 0
1 change: 0 additions & 1 deletion collie/models/base.py
@@ -337,7 +337,6 @@ def from_pretrained(cls, model_path_or_name: str, config: Union[CollieConfig, st
                 if name in state_dict:
                     assert param.data.shape == state_dict[name].data.shape, f"The shape of the parameter corresponding to the `{name}` does not match: {param.data.shape} vs {state_dict[name].data.shape}"
                     param.data = value.to(param.device)
-
 
         if config.peft_config.peft_type is not None:
             model = get_peft_model(model, config.peft_config)
6 changes: 5 additions & 1 deletion collie/models/internlm/model.py
@@ -80,12 +80,16 @@ def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
         )
         self.eps = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        return (hidden_states * self.weight).to(input_dtype)
 
 
 class InternLMLayer(nn.Module):
5 changes: 4 additions & 1 deletion collie/models/internlm2/model.py
@@ -122,12 +122,15 @@ def __init__(self, hidden_size, eps=1e-6):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
6 changes: 5 additions & 1 deletion collie/models/llama/model.py
@@ -80,12 +80,16 @@ def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
         )
         self.eps = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+

Review comment (Collaborator): In llama, the result of multiplying hidden_states by self.weight is fp32, so it should be cast back to 16-bit before being passed to the next layer; otherwise a dtype-mismatch error is raised. It is worth checking whether the other models have the same problem. (See the short sketch after this file's diff.)


     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        return (hidden_states * self.weight).to(input_dtype)
 
     def post_init(self):
         """
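To make the reviewer's point concrete, here is a minimal standalone sketch (not CoLLiE code; the class and variable names are illustrative). Once the norm weight is promoted to fp32, multiplying the activations by it promotes them to fp32 as well, so the result must be cast back to the input dtype before it reaches the next half-precision layer:

import torch
from torch import nn

class ToyRMSNorm(nn.Module):
    """Illustrative RMSNorm, not the CoLLiE implementation."""
    def __init__(self, dim, eps=1e-5, dtype=torch.bfloat16):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))
        self.eps = eps

    def set_norm_precision_to_float32(self):
        # Keep only the norm weight in fp32 for numerical stability.
        self.weight.data = self.weight.data.to(torch.float32)

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # bf16/fp16 * fp32 weight promotes the product to fp32,
        # so cast back to the activation dtype before returning.
        return (hidden_states * self.weight).to(input_dtype)

norm = ToyRMSNorm(8)
next_layer = nn.Linear(8, 8).to(torch.bfloat16)  # stand-in for the following attention/MLP block
norm.set_norm_precision_to_float32()

x = torch.randn(2, 8, dtype=torch.bfloat16)
out = norm(x)
assert out.dtype == torch.bfloat16  # without the final .to(input_dtype) this would be fp32
next_layer(out)                     # a bf16 layer would reject an fp32 input with a dtype-mismatch error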
6 changes: 5 additions & 1 deletion collie/models/moss/model.py
@@ -80,12 +80,16 @@ def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None):
         )
         self.eps = eps
 
+    def set_norm_precision_to_float32(self):
+        self.weight.data = self.weight.data.to(torch.float32)
+
     def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states * self.weight
+        return (hidden_states * self.weight).to(input_dtype)
 
 
 class MossBlock(nn.Module):
14 changes: 12 additions & 2 deletions collie/models/moss_moon/model.py
@@ -271,6 +271,10 @@ def __init__(self, config, layer_idx):
         self.use_cache = False
         self.hidden_states = None
 
+    def set_norm_precision_to_float32(self):
+        self.ln_1.weight.data = self.ln_1.weight.data.to(torch.float32)
+        self.ln_1.bias.data = self.ln_1.bias.data.to(torch.float32)
+
     def _forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
@@ -284,7 +288,8 @@ def _forward(
         Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
     ]:
         residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
+        input_dtype = hidden_states.dtype
+        hidden_states = self.ln_1(hidden_states.to(torch.float32)).to(input_dtype)
         attn_outputs = self.attn(
             hidden_states=hidden_states,
             layer_past=layer_past,
@@ -394,6 +399,10 @@ def __init__(self, config):
         self.h = nn.ModuleList([MossBlock(config, i) for i in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
+    def set_norm_precision_to_float32(self):
+        self.ln_f.weight.data = self.ln_f.weight.data.to(torch.float32)
+        self.ln_f.bias.data = self.ln_f.bias.data.to(torch.float32)
+
     def forward(
         self,
         input_ids,
@@ -425,7 +434,8 @@ def forward(
             all_hidden_states += (input_dict["hidden_states"],)
             input_dict.update(l(input_dict))
 
-        hidden_states = self.ln_f(input_dict["hidden_states"])
+        input_dtype = input_dict["hidden_states"].dtype
+        hidden_states = self.ln_f(input_dict["hidden_states"].to(torch.float32)).to(input_dtype)
         all_hidden_states += (hidden_states,)
 
         past_key_values = None