diff --git a/configs/local_setup.yml b/configs/local_setup.yml
index 99b3bdfd6..b17d5f024 100644
--- a/configs/local_setup.yml
+++ b/configs/local_setup.yml
@@ -26,5 +26,6 @@
   "log-dir": "logs",
   "use_wandb": True,
   "wandb_host": "https://api.wandb.ai",
-  "wandb_project": "neox"
+  "wandb_project": "neox",
+  "ia3_tuning": False
 }
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 9502f5b32..1661c962d 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -227,6 +227,8 @@ def load_checkpoint(
 ):
     """Load a model checkpoint and return the iteration."""
     if neox_args.deepspeed:
+        if neox_args.ia3_tuning:
+            neox_args.load_module_strict = False
         load_optim_and_scheduler = (
             not neox_args.no_load_optim
         )  # TODO: These should be configured by separate args
@@ -241,6 +243,7 @@ def load_checkpoint(
             load_optimizer_states=load_optim_and_scheduler,
             load_lr_scheduler_states=load_optim_and_scheduler,
             tag=tag,
+            load_module_strict=neox_args.load_module_strict,
         )
 
     if checkpoint_name is None:
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e753a3532..1e8531693 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -18,6 +18,7 @@
 """Transformer."""
 
 import math
+import sys
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
@@ -88,20 +89,21 @@ def __init__(
 
         # auto scale so geglu has equal parameters
         ff_mult = 4 * 2 / 3 if self.activation_type == "geglu" else 4
-        ff_dim = (
+        self.ff_dim = (
             int(ff_mult * neox_args.hidden_size) * 2
             if self.activation_type == "geglu"
             else ff_mult * neox_args.hidden_size
         )
+
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
             neox_args=neox_args,
             input_size=neox_args.hidden_size,
-            output_size=ff_dim,
+            output_size=self.ff_dim,
             gather_output=False,
             init_method=init_method,
             skip_bias_add=True,
         )
-        ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim
+        ff_dim_in = self.ff_dim // 2 if self.activation_type == "geglu" else self.ff_dim
         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
             neox_args=neox_args,
@@ -134,6 +136,56 @@ def forward(self, hidden_states):
         return output, output_bias
 
 
+class ParallelMLPIA3(ParallelMLP):
+    """MLP.
+
+    MLP will take the input with h hidden state, project it to 4*h
+    hidden dimension, perform nonlinear transformation, and project the
+    state back into h hidden dimension. At the end, dropout is also
+    applied.
+
+    Applies IA3 rescaling of each column after non-linearity:
+    https://arxiv.org/pdf/2205.05638.pdf
+    """
+
+    def __init__(
+        self, neox_args, init_method, output_layer_init_method, parallel_output=False
+    ):
+        super().__init__(
+            neox_args,
+            init_method=init_method,
+            output_layer_init_method=output_layer_init_method,
+            parallel_output=parallel_output,
+        )
+
+        world_size = mpu.get_model_parallel_world_size()
+        self.hidden_size_per_partition = mpu.divide(self.ff_dim, world_size)  # 4hp
+        self.l_ff = create_ia3_parameter(self.hidden_size_per_partition, neox_args)
+
+    def forward(self, hidden_states):
+
+        # [s, b, 4hp]
+        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
+
+        if (
+            self.activation_type == "gelu" and self.bias_gelu_fusion
+        ) or self.activation_type == "geglu":
+            intermediate_parallel = self.activation_func(
+                intermediate_parallel, bias_parallel
+            )
+        else:
+            intermediate_parallel = self.activation_func(
+                intermediate_parallel + bias_parallel
+            )
+
+        # Apply IA3 rescaling:
+        intermediate_parallel *= self.l_ff
+
+        # [s, b, h]
+        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
+        return output, output_bias
+
+
 class ParallelLinear(nn.Module):
     """
     A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size
@@ -590,6 +642,154 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
         return output, bias
 
 
+class ParallelSelfAttentionIA3(ParallelSelfAttention):
+    """Applies IA3 rescaling to key and value vectors per:
+    https://arxiv.org/pdf/2205.05638.pdf
+    """
+
+    def __init__(
+        self,
+        neox_args,
+        attention_mask_func,
+        init_method,
+        output_layer_init_method,
+        layer_number,
+        rpe=None,
+        rotary=False,
+        use_cache=False,
+        parallel_output=False,
+    ):
+        super().__init__(
+            neox_args,
+            attention_mask_func,
+            init_method,
+            output_layer_init_method,
+            layer_number,
+            rpe=rpe,
+            rotary=rotary,
+            use_cache=use_cache,
+            parallel_output=parallel_output,
+        )
+        self.l_k = create_ia3_parameter(self.hidden_size_per_partition, neox_args)
+        self.l_v = create_ia3_parameter(self.hidden_size_per_partition, neox_args)
+
+    def forward(self, hidden_states, attention_mask, layer_past=None):
+
+        # hidden_states: [sq, b, h]
+
+        # =====================
+        # Query, Key, and Value
+        # =====================
+
+        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+        mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
+        new_tensor_shape = mixed_x_layer.size()[:-1] + (
+            self.num_attention_heads_per_partition,
+            3 * self.hidden_size_per_attention_head,
+        )
+        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+        (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(
+            mixed_x_layer, 3
+        )
+
+        def _apply_ia3_rescaling(layer, scale_vector):
+            """Apply IA3 rescaling:
+
+            Reshapes: [sq, b, np, hn] -> [sq, b, np * hn] to perform
+            rescaling and then back to [sq, b, np, hn].
+
+            Note: np * hn == h/p == self.hidden_size_per_partition
+            """
+            layer_size = layer.shape
+            layer = layer.reshape(layer_size[0], layer_size[1], -1)
+            layer *= scale_vector
+            return layer.reshape(layer_size)
+
+        key_layer = _apply_ia3_rescaling(key_layer, self.l_k)
+        value_layer = _apply_ia3_rescaling(value_layer, self.l_v)
+
+        if exists(self.rotary_emb):
+            if exists(self.rotary_ndims):
+                # partial rotary
+                query_rot, query_pass = (
+                    query_layer[..., : self.rotary_ndims],
+                    query_layer[..., self.rotary_ndims :],
+                )
+                key_rot, key_pass = (
+                    key_layer[..., : self.rotary_ndims],
+                    key_layer[..., self.rotary_ndims :],
+                )
+            else:
+                # full rotary
+                query_rot, key_rot = query_layer, key_layer
+            apply_rotary_fn = (
+                apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
+            )
+
+            seq_len = key_layer.shape[0]
+            offset = 0
+            if exists(layer_past) and layer_past.numel() > 0:
+                offset = layer_past[0].shape[0]
+                seq_len += offset
+            cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
+            query_layer, key_layer = apply_rotary_fn(
+                query_rot, key_rot, cos, sin, offset=offset
+            )
+
+            if exists(self.rotary_ndims):
+                query_layer = torch.cat((query_layer, query_pass), dim=-1)
+                key_layer = torch.cat((key_layer, key_pass), dim=-1)
+
+        # ==================================
+        # Cache key and value for inference
+        # ==================================
+
+        if exists(layer_past) and layer_past.numel() > 0:
+            past_key, past_value = layer_past
+            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
+            value_layer = torch.cat(
+                (past_value.type_as(value_layer), value_layer), dim=0
+            )
+
+        if self.use_cache:
+            present = torch.stack((key_layer, value_layer))
+
+        if self.use_flash_attention:
+            context_layer = self.flash_attention(query_layer, key_layer, value_layer)
+        elif not self.sparse:
+            context_layer = self.attention(
+                query_layer, key_layer, value_layer, layer_past, attention_mask
+            )
+        else:
+            context_layer = self.sparse_attention(
+                query_layer, key_layer, value_layer, attention_mask
+            )
+
+        # [b, np, sq, hn] --> [sq, b, np, hn]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+        # [sq, b, np, hn] --> [sq, b, hp]
+        new_context_layer_shape = context_layer.size()[:-2] + (
+            self.hidden_size_per_partition,
+        )
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        # =================
+        # Output. [sq, b, h]
+        # =================
+
+        output, bias = self.dense(context_layer)
+
+        if self.use_cache:
+            output = [output, present]
+
+        return output, bias
+
+
 class ParallelTransformerLayer(nn.Module):
     """A single transformer layer.
@@ -625,9 +825,10 @@ def __init__(
 
         if self.gpt_j_residual:
             self.reduce = mpu.mappings.reduce_from_model_parallel_region
 
+        self_attention_cls = getattr(sys.modules[__name__], neox_args.self_attention_cls)
         # Self attention.
-        self.attention = ParallelSelfAttention(
+        self.attention = self_attention_cls(
             neox_args=neox_args,
             attention_mask_func=attention_mask_func,
             init_method=init_method,
@@ -645,7 +846,8 @@ def __init__(
         self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)
 
         # MLP
-        self.mlp = ParallelMLP(
+        parallel_mlp_cls = getattr(sys.modules[__name__], neox_args.parallel_mlp_cls)
+        self.mlp = parallel_mlp_cls(
             neox_args=neox_args,
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
@@ -804,3 +1006,29 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=Non
         return logits_parallel
 
     return mpu.gather_from_model_parallel_region(logits_parallel)
+
+
+def create_ia3_parameter(param_size, neox_args):
+    """Create a parameter vector for use in IA3 scaling, per:
+    https://arxiv.org/pdf/2205.05638.pdf
+    """
+    if neox_args.use_cpu_initialization:
+        param = torch.nn.Parameter(
+            torch.empty(
+                param_size, dtype=neox_args.params_dtype
+            )
+        )
+    else:
+        param = torch.nn.Parameter(
+            torch.empty(
+                param_size,
+                device=torch.cuda.current_device(),
+                dtype=neox_args.params_dtype,
+            )
+        )
+    param.model_parallel = True
+    param.partition_dim = 0
+    # Always initialize to ones.
+    with torch.no_grad():
+        torch.nn.init.ones_(param)
+    return param
diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index 3e9940c1e..77a6ddb26 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -48,14 +48,14 @@ def get_params_for_weight_decay_optimization(module, neox_args):
                 [
                     p
                     for n, p in list(module_._parameters.items())
-                    if p is not None and n != "bias"
+                    if p is not None and n not in neox_args.no_weight_decay_params
                 ]
             )
             no_weight_decay_params["params"].extend(
                 [
                     p
                     for n, p in list(module_._parameters.items())
-                    if p is not None and n == "bias"
+                    if p is not None and n in neox_args.no_weight_decay_params
                 ]
             )
     if neox_args.weight_decay == 0.0:
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 95e6b6b8e..65d5c3788 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import subprocess
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 try:
     from .template import NeoXArgsTemplate
@@ -355,11 +355,37 @@ class NeoXArgsModel(NeoXArgsTemplate):
     """
 
     output_layer_parallelism: Literal["row", "column"] = "row"
 
     """
     Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
    """
 
+    ia3_tuning: bool = False
+    """
+    Run IA3 tuning based off:
+    Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning
+    https://arxiv.org/pdf/2205.05638.pdf
+    """
+
+    self_attention_cls: str = "ParallelSelfAttention"
+    """
+    Default class to use for self attention
+    """
+
+    parallel_mlp_cls: str = "ParallelMLP"
+    """
+    Default class to use for the parallel MLP block
+    """
+
+    no_weight_decay_params: list = field(default_factory=lambda: ["bias", "l_ff", "l_v", "l_k"])
+    """
+    Parameter names that weight decay will not be applied to
+    """
+
+    load_module_strict: bool = True
+    """
+    Whether to strictly enforce that the keys in state_dict of module & checkpoint match.
+    """
 
 @dataclass
 class NeoXArgsOptimizer(NeoXArgsTemplate):
diff --git a/megatron/training.py b/megatron/training.py
index 6ebbe780d..5cf852608 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -385,6 +385,10 @@ def get_model(neox_args, use_cache=False):
     # If mup isn't being used anyways, this has no effect.
     old_use_mup = neox_args.use_mup
     neox_args.use_mup = False
+    if neox_args.ia3_tuning:
+        neox_args.parallel_mlp_cls = "ParallelMLPIA3"
+        neox_args.self_attention_cls = "ParallelSelfAttentionIA3"
+
     model = GPT2ModelPipe(
         neox_args=neox_args,
         num_tokentypes=0,
@@ -412,6 +416,16 @@ def get_model(neox_args, use_cache=False):
         for name, param in model.named_parameters():
             if not "soft_embedding" in name:
                 param.requires_grad = False
+    elif neox_args.ia3_tuning:
+        layers_to_train = ["l_ff", "l_k", "l_v"]
+        for name, param in model.named_parameters():
+            if not any(x in name for x in layers_to_train):
+                param.requires_grad = False
+
+        trainable_params = sum(
+            p.numel() for p in model.parameters() if p.requires_grad
+        )
+        print(f"Number of trainable parameters (current partition): {trainable_params}")
 
     if not neox_args.is_pipe_parallel:
         # Export PipeParallel model to nn.Sequential model to avoid the overhead of deepspeed's pipe parallel training
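Usage sketch (illustrative, not part of the patch above): enabling IA3 only requires setting the new ia3_tuning flag in a config. get_model then swaps in ParallelMLPIA3 / ParallelSelfAttentionIA3 and freezes everything except the l_ff, l_k and l_v vectors, and load_checkpoint relaxes load_module_strict so a pretrained checkpoint without those vectors can still be loaded. The fragment below follows the style of configs/local_setup.yml; the checkpoint path and the extra keys are assumptions, not values prescribed by this change.

    {
      # New flag added by this patch; selects the IA3 MLP and self-attention classes.
      "ia3_tuning": True,
      # Hypothetical directory holding the pretrained checkpoint to adapt.
      "load": "checkpoints/pretrained",
      # The IA3 vectors are new parameters, so skip loading the old optimizer state.
      "no_load_optim": True,
    }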