diff --git a/README.md b/README.md
index a5b8f601..c8caa4d5 100644
--- a/README.md
+++ b/README.md
@@ -88,6 +88,7 @@ Notes:
 The following models from Hugging Face hub are currently supported
 - [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
 - [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
+- [microsoft/phi-4](https://huggingface.co/microsoft/phi-4)
 - [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b)
 - [meta-llama/Llama-2-13b-hf](https://huggingface.co/meta-llama/Llama-2-13b)
 - [meta-llama/Llama-2-70b-hf](https://huggingface.co/meta-llama/Llama-2-70b)
@@ -95,6 +96,12 @@ The following models from Hugging Face hub are currently supported
 - [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
 - [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
 - [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)
+- [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B)
+- [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
+- [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
+- [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct)
+- [meta-llama/Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)
+- [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)
 - [facebook/opt-125m](https://huggingface.co/facebook/opt-125m)
 - [facebook/opt-1.3b](https://huggingface.co/facebook/opt-1.3b)
 - [facebook/opt-2.7b](https://huggingface.co/facebook/opt-2.7b)
diff --git a/experiments/bo_options.py b/experiments/bo_options.py
index dea363c7..1bc7a9a7 100644
--- a/experiments/bo_options.py
+++ b/experiments/bo_options.py
@@ -44,7 +44,10 @@ def lora_target_map(model: str):
                 'lm_head',
             ],
         }
-        case 'meta-llama/Llama-2-7b-hf' | 'meta-llama/Llama-2-13b-hf' | 'meta-llama/Llama-2-70b-hf' | 'meta-llama/Meta-Llama-3-8B' | 'meta-llama/Meta-Llama-3-8B-Instruct' | 'meta-llama/Meta-Llama-3-70B' | 'meta-llama/Meta-Llama-3-70B-Instruct':
+        case 'meta-llama/Llama-2-7b-hf' | 'meta-llama/Llama-2-13b-hf' | 'meta-llama/Llama-2-70b-hf' | \
+            'meta-llama/Meta-Llama-3-8B' | 'meta-llama/Meta-Llama-3-8B-Instruct' | 'meta-llama/Meta-Llama-3-70B' | 'meta-llama/Meta-Llama-3-70B-Instruct' | \
+            'meta-llama/Llama-3.2-1B' | 'meta-llama/Llama-3.2-1B-Instruct' | 'meta-llama/Llama-3.2-3B' | 'meta-llama/Llama-3.2-3B-Instruct' | \
+            'meta-llama/Llama-3.1-8B' | 'meta-llama/Llama-3.1-8B-Instruct':
             return {
                 'qkv_proj': ['k_proj', 'q_proj', 'v_proj'],
                 'attn_head': ['k_proj', 'q_proj', 'v_proj', 'o_proj'],
@@ -81,7 +84,7 @@ def lora_target_map(model: str):
                 'lm_head',
             ],
         }
-        case 'microsoft/Phi-3-mini-4k-instruct':
+        case 'microsoft/Phi-3-mini-4k-instruct' | 'microsoft/phi-4':
             return {
                 'qkv_proj': ['qkv_proj'],
                 'attn_head': ['qkv_proj', 'o_proj'],
diff --git a/experiments/run_finetuning.py b/experiments/run_finetuning.py
index 5c67c6f9..d7e267e8 100644
--- a/experiments/run_finetuning.py
+++ b/experiments/run_finetuning.py
@@ -330,6 +330,7 @@ def finetuning_main(args: argparse.Namespace) -> None:
         args=training_args,
         optimizers=(optimizer, lr_scheduler),
         callbacks=[EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience)],
+        eval_dataset=args.ppl_eval_dataset,  # update for transformers==4.48.0
     )

     # required to enable gradient_checkpointing
diff --git a/pyproject.toml b/pyproject.toml
index f18dffa9..16bad3d5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
     "numpy",
     "torch",
     "tqdm",
-    "transformers==4.41.0",
+    "transformers==4.48.0",
 ]

 [project.optional-dependencies]
diff --git a/src/slicegpt/adapters/llama_adapter.py b/src/slicegpt/adapters/llama_adapter.py
index 1a7eb920..4f97069e 100644
--- a/src/slicegpt/adapters/llama_adapter.py
+++ b/src/slicegpt/adapters/llama_adapter.py
@@ -13,6 +13,7 @@
 from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaForCausalLM, LlamaRMSNorm

 from slicegpt.model_adapter import LayerAdapter, ModelAdapter
+from typing import Optional, Tuple


 class CompressedLlamaDecoderLayer(LlamaDecoderLayer):
@@ -22,16 +23,19 @@ class CompressedLlamaDecoderLayer(LlamaDecoderLayer):
     but with the addition of a shortcut_Q attribute. This attribute is used to rotate the residual tensors.
     """

-    def forward(
+    def forward(  # Update forward() for transformers 4.48.0
         self,
-        hidden_states: Tensor,
-        attention_mask: Tensor | None = None,
-        position_ids: LongTensor | None = None,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         past_key_value: tuple[Tensor] | None = None,
-        output_attentions: bool | None = False,
-        use_cache: bool | None = False,
-        **kwargs,
-    ) -> tuple:
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -51,15 +55,28 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            **kwargs,
-        )
+        # TODO(mercy): Not sure if this works for older versions of LlaMa
+        try:
+            hidden_states, self_attn_weights = self.self_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+        except ValueError:  # Older LlaMa version
+            hidden_states, self_attn_weights, present_key_value = self.self_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                **kwargs,
+            )
         if self.attn_shortcut_Q is not None:
             rotated_residual = matmul(residual, self.attn_shortcut_Q)
             hidden_states = rotated_residual + hidden_states
@@ -82,7 +99,7 @@ def forward(
         if output_attentions:
             outputs += (self_attn_weights,)

-        if use_cache:
+        if use_cache and 'present_key_value' in locals():
             outputs += (present_key_value,)

         return outputs
@@ -224,7 +241,7 @@ def _from_pretrained(
         local_files_only: bool = False,
         token: str | bool | None = None,
     ) -> ModelAdapter | None:
-        if not (model_name.startswith("meta-llama/Llama-2") or model_name.startswith("meta-llama/Meta-Llama-3")):
+        if not (model_name.startswith("meta-llama/Llama-") or model_name.startswith("meta-llama/Meta-Llama-3")):
             return None

         model = LlamaForCausalLM.from_pretrained(
@@ -244,7 +261,7 @@ def _from_uninitialized(
         local_files_only: bool = False,
         token: str | bool | None = None,
     ) -> ModelAdapter | None:
-        if not (model_name.startswith("meta-llama/Llama-2") or model_name.startswith("meta-llama/Meta-Llama-3")):
+        if not (model_name.startswith("meta-llama/Llama-") or model_name.startswith("meta-llama/Meta-Llama-3")):
             return None

         class UninitializedLlamaForCausalLM(LlamaForCausalLM):
diff --git a/src/slicegpt/adapters/opt_adapter.py b/src/slicegpt/adapters/opt_adapter.py
index 709d2be3..f725a0cd 100644
--- a/src/slicegpt/adapters/opt_adapter.py
+++ b/src/slicegpt/adapters/opt_adapter.py
@@ -14,6 +14,7 @@

 from slicegpt.model_adapter import LayerAdapter, ModelAdapter
+from typing import Optional, Tuple


 class CompressedOPTDecoderLayer(OPTDecoderLayer):
     """
@@ -24,12 +25,13 @@ def forward(
         self,
-        hidden_states: Tensor,
-        attention_mask: Tensor | None = None,
-        layer_head_mask: Tensor | None = None,
-        past_key_value: tuple[Tensor] | None = None,
-        output_attentions: bool | None = False,
-        use_cache: bool | None = False,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        position_ids: Optional[torch.LongTensor] = None
     ) -> tuple:
         """
         Args:
@@ -57,6 +59,7 @@ def forward(
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=past_key_value,
+            position_ids=position_ids,
             attention_mask=attention_mask,
             layer_head_mask=layer_head_mask,
             output_attentions=output_attentions,
             use_cache=use_cache,
diff --git a/src/slicegpt/adapters/phi2_adapter.py b/src/slicegpt/adapters/phi2_adapter.py
index 3d1aa1bc..10ea5304 100644
--- a/src/slicegpt/adapters/phi2_adapter.py
+++ b/src/slicegpt/adapters/phi2_adapter.py
@@ -12,7 +12,7 @@
 from torch.nn import LayerNorm, Linear, Module
 from transformers import PretrainedConfig, PreTrainedTokenizerBase
 from transformers.models.phi.modeling_phi import PhiConfig, PhiDecoderLayer, PhiForCausalLM
-
+from typing import Optional, Tuple
 from slicegpt.model_adapter import LayerAdapter, ModelAdapter


@@ -25,12 +25,14 @@ def forward(
         self,
-        hidden_states: Tensor,
-        attention_mask: Tensor | None = None,
-        position_ids: LongTensor | None = None,
-        output_attentions: bool | None = False,
-        use_cache: bool | None = False,
-        past_key_value: tuple[Tensor] | None = None,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs
     ) -> tuple:
         """
         Args:
@@ -55,14 +57,27 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
-        attn_outputs, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-        )
+        try:
+            attn_outputs, self_attn_weights = self.self_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+        except TypeError:
+            attn_outputs, self_attn_weights, present_key_value = self.self_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
         attn_outputs = self.resid_dropout(attn_outputs)
         feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
@@ -78,7 +93,7 @@ def forward(
         if output_attentions:
             outputs += (self_attn_weights,)

-        if use_cache:
+        if use_cache and 'present_key_value' in locals():
             outputs += (present_key_value,)

         return outputs
diff --git a/src/slicegpt/adapters/phi3_adapter.py b/src/slicegpt/adapters/phi3_adapter.py
index ccd268da..a191661f 100644
--- a/src/slicegpt/adapters/phi3_adapter.py
+++ b/src/slicegpt/adapters/phi3_adapter.py
@@ -31,16 +31,18 @@ class CompressedPhi3DecoderLayer(Phi3DecoderLayer):
     but with the addition of a shortcut_Q attribute. This attribute is used to rotate the residual tensors.
     """
+
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        **kwargs,
-    ) -> Tuple[Any]:
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Tensor] = None,
+            output_attentions: Optional[bool] = False,
+            use_cache: Optional[bool] = False,
+            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+            **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -68,15 +70,26 @@ def forward(

         hidden_states = self.input_layernorm(hidden_states)

-        # Self Attention
-        attn_outputs, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-        )
+        try:
+            attn_outputs, self_attn_weights = self.self_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+        except ValueError:  # Older version
+            attn_outputs, self_attn_weights, present_key_value = self.self_attn(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )

         if self.attn_shortcut_Q is not None:
             rotated_residual = matmul(residual, self.attn_shortcut_Q)
@@ -99,7 +112,7 @@ def forward(
         if output_attentions:
             outputs += (self_attn_weights,)

-        if use_cache:
+        if use_cache and 'present_key_value' in locals():
             outputs += (present_key_value,)

         return outputs
@@ -237,9 +250,13 @@ def _from_pretrained(
         local_files_only: bool = False,
         token: str | bool | None = None,
     ) -> ModelAdapter | None:
-        if not model_name.startswith("microsoft/Phi-3-mini-4k-instruct"):
+        if not (model_name.startswith("microsoft/Phi-3-mini-4k-instruct") or \
+                model_name.startswith("microsoft/phi-4") or \
+                model_name.startswith("microsoft/Phi-4-mini-instruct")):  # TODO: Unfinished
             return None

+        print(model_path,dtype,token,local_files_only)
+
         model = Phi3ForCausalLM.from_pretrained(
             model_path, torch_dtype=dtype, token=token, local_files_only=local_files_only
         )
@@ -257,7 +274,9 @@ def _from_uninitialized(
         local_files_only: bool = False,
         token: str | bool | None = None,
     ) -> ModelAdapter | None:
-        if not model_name.startswith("microsoft/Phi-3-mini-4k-instruct"):
+        if not (model_name.startswith("microsoft/Phi-3-mini-4k-instruct") or \
+                model_name.startswith("microsoft/phi-4") or \
+                model_name.startswith("microsoft/Phi-4-mini-instruct")):
             return None

         class UninitializedPhi3ForCausalLM(Phi3ForCausalLM):
diff --git a/src/slicegpt/model_utils.py b/src/slicegpt/model_utils.py
index 4e8833aa..e4300c64 100644
--- a/src/slicegpt/model_utils.py
+++ b/src/slicegpt/model_utils.py
@@ -26,7 +26,7 @@ def get_layer0_inputs(model_adapter: ModelAdapter, batch: Tensor) -> tuple[Tenso
     """
     # Move embeddings to device.
    for W in model_adapter.get_embeddings():
-        W.weight = torch.nn.Parameter(W.weight.to(config.device))
+        W.weight = torch.nn.Parameter(W.weight.to(model_adapter.model.device))

     class Catcher(torch.nn.Module):
         def __init__(self):
@@ -42,7 +42,7 @@ def forward(self, *args, **kwargs):
     model_adapter.set_raw_layer_at(0, layer0_catcher)

     try:
-        batch = utils.map_tensors(batch, device=config.device)
+        batch = utils.map_tensors(batch, device=model_adapter.model.device)
         model_adapter.model(**batch)
     except ValueError:
         pass