diff --git a/experiments/run_lm_eval.py b/experiments/run_lm_eval.py
index 15b2f9f3..0e769f6f 100644
--- a/experiments/run_lm_eval.py
+++ b/experiments/run_lm_eval.py
@@ -7,14 +7,15 @@
 import os
 
 import lm_eval
+from slicegpt.model_adapter import ModelAdapter
 import torch
-import wandb
 from lm_eval import tasks
 from lm_eval import utils as lm_eval_utils
 from lm_eval.api.registry import ALL_TASKS
 from lm_eval.models.huggingface import HFLM
 from lm_eval.tasks import initialize_tasks
+import wandb
 
 from slicegpt import gpu_utils, hf_utils, utils
 from slicegpt.config import config
 
@@ -120,29 +121,34 @@ def eval_main(args: argparse.Namespace) -> None:
     if args.sliced_model_path:
         # load the sliced model
         logging.info(f"Loading sliced {args.model} model from {args.sliced_model_path} with sparsity {args.sparsity}")
-        model_adapter, tokenizer = hf_utils.load_sliced_model(
+        model, tokenizer = hf_utils.load_sliced_model(
             args.model,
             args.sliced_model_path,
             sparsity=args.sparsity,
             token=args.hf_token,
             round_interval=args.round_interval,
         )
+        if isinstance(model, ModelAdapter):
+            model = model.model
+        else:
+            model = model.to(config.dtype)
     else:
         # load the original model
         logging.info(f"Loading {args.model} model")
         model_adapter, tokenizer = hf_utils.get_model_and_tokenizer(args.model, args.model_path, token=args.hf_token)
+        model = model_adapter.model
 
     # the lm eval harness ties the weights, but this should not be done for sliced models unless the lm_head was sliced
-    model_adapter.model.tie_weights = lambda: None
+    model.tie_weights = lambda: None
 
     if args.distribute_model:
         # distribute model across available GPUs
         gpu_utils.distribute_model(model_adapter)
     else:
-        model_adapter.model.to(config.device)
+        model.to(config.device)
 
     ### LM Eval Harness ###
-    hflm = HFLM(pretrained=model_adapter.model, tokenizer=tokenizer, batch_size=args.batch_size)
+    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=args.batch_size)
 
     if args.tasks is None:
         task_names = tasks.ALL_TASKS
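A minimal sketch of the evaluation path above, assuming lm-eval 0.4's `simple_evaluate` API; the checkpoint path, sparsity and task names are illustrative.

import lm_eval
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import initialize_tasks

from slicegpt import hf_utils
from slicegpt.config import config

initialize_tasks()  # register the harness task registry (lm-eval 0.4.x)

# Load a previously sliced model (illustrative path and sparsity).
model, tokenizer = hf_utils.load_sliced_model(
    "microsoft/phi-2", "path/to/sliced_phi2_dir", sparsity=0.25, round_interval=8
)
model.to(config.device)

hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)
results = lm_eval.simple_evaluate(model=hflm, tasks=["piqa", "arc_easy"])["results"]
print(results)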
diff --git a/experiments/run_slicegpt.py b/experiments/run_slicegpt.py
index e02332c3..fcec8dff 100755
--- a/experiments/run_slicegpt.py
+++ b/experiments/run_slicegpt.py
@@ -8,8 +8,8 @@
 import shutil
 
 import torch
-import wandb
 
+import wandb
 from slicegpt import data_utils, gpu_utils, hf_utils, layernorm_fusion, rotate, utils
 from slicegpt.config import config
 from slicegpt.slicing_scheduler import ConstSlicingScheduler
@@ -137,20 +137,20 @@ def slicing_main(args: argparse.Namespace) -> None:
 
     if args.sliced_model_path:
         # load the model from sliced_model_path to compute perplexity and skip rotation and slicing
-        model_adapter, tokenizer = hf_utils.load_sliced_model(
+        model, tokenizer = hf_utils.load_sliced_model(
            args.model,
            args.sliced_model_path,
            sparsity=args.sparsity,
            round_interval=args.round_interval,
            token=args.hf_token,
        )
+        model = model.to(config.dtype)
     else:
         # load one of the pre-trained models
         model_adapter, tokenizer = hf_utils.get_model_and_tokenizer(
             args.model, args.model_path, token=args.hf_token, dtype=config.dtype
         )
-
-    model = model_adapter.model
+        model = model_adapter.model
 
     def reset_model_device() -> None:
         if args.distribute_model:
@@ -228,14 +228,17 @@ def reset_model_device() -> None:
         sliced_model_dir = pathlib.Path(args.save_dir)
         sliced_model_dir.mkdir(parents=True, exist_ok=True)
 
-        sliced_model_name = sliced_model_dir / f'{pathlib.Path(args.model).name}_{args.sparsity}.pt'
-
-        # Save the sliced model
-        torch.save(model.state_dict(), sliced_model_name)
-
-        # Save the slicing config
-        config_path = sliced_model_name.with_suffix('.json')
-        config_path.write_text(model_adapter.slicing_conf.to_json_string())
+        # Save the sliced model in HF format for Phi and Llama
+        hf_utils.save_sliced_model(
+            args.model,
+            config.dtype,
+            model,
+            scheduler,
+            sliced_model_dir,
+            args.sparsity,
+            new_embedding_dimension,
+            model_adapter.slicing_conf,
+        )
 
         # If slicing a local model, also save HF config files in sliced model dir
         if args.model_path:
diff --git a/src/slicegpt/__init__.py b/src/slicegpt/__init__.py
index a2e75aff..dc494cd2 100644
--- a/src/slicegpt/__init__.py
+++ b/src/slicegpt/__init__.py
@@ -4,6 +4,8 @@
 from .adapters.llama_adapter import LlamaModelAdapter
 from .adapters.opt_adapter import OPTModelAdapter
 from .adapters.phi2_adapter import Phi2ModelAdapter
+from .adapters.sliced_llama import SlicedLlama, SlicedLlamaConfig, SlicedLlamaForCausalLM
+from .adapters.sliced_phi import SlicedPhi, SlicedPhi2Config, SlicedPhiForCausalLM
 from .data_utils import get_dataset, prepare_dataloader
 from .gpu_utils import benchmark, distribute_model, evaluate_ppl
 from .hf_utils import get_model_and_tokenizer, load_sliced_model
diff --git a/src/slicegpt/adapters/__init__.py b/src/slicegpt/adapters/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/slicegpt/adapters/llama_adapter.py b/src/slicegpt/adapters/llama_adapter.py
index 78e60a2b..9cc4e028 100644
--- a/src/slicegpt/adapters/llama_adapter.py
+++ b/src/slicegpt/adapters/llama_adapter.py
@@ -14,6 +14,7 @@
 from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaForCausalLM, LlamaRMSNorm
 
 from slicegpt.model_adapter import LayerAdapter, ModelAdapter
+from slicegpt.modules import RMSN
 
 
 class CompressedLlamaDecoderLayer(LlamaDecoderLayer):
@@ -23,6 +24,11 @@ class CompressedLlamaDecoderLayer(LlamaDecoderLayer):
     but with the addition of a shortcut_Q attribute. This attribute is used to rotate the residual tensors.
     """
 
+    def __init__(self, config: LlamaConfig, layer_idx: int, replace_layernorm: bool = False):
+        super().__init__(config, layer_idx)
+        if replace_layernorm:
+            self.input_layernorm = RMSN(config.hidden_size)
+
     def forward(
         self,
         hidden_states: Tensor,
diff --git a/src/slicegpt/adapters/opt_adapter.py b/src/slicegpt/adapters/opt_adapter.py
index bd62da1c..2219e446 100644
--- a/src/slicegpt/adapters/opt_adapter.py
+++ b/src/slicegpt/adapters/opt_adapter.py
@@ -14,6 +14,7 @@
 from transformers.models.opt.modeling_opt import OPTConfig, OPTDecoderLayer, OPTForCausalLM
 
 from slicegpt.model_adapter import LayerAdapter, ModelAdapter
+from slicegpt.modules import RMSN
 
 
 class CompressedOPTDecoderLayer(OPTDecoderLayer):
@@ -23,6 +24,11 @@ class CompressedOPTDecoderLayer(OPTDecoderLayer):
     We also support the input rotation and mean subtraction in this class (if needed).
""" + def __init__(self, config: OPTConfig, replace_layernorm: bool = False): + super().__init__(config) + if replace_layernorm: + self.input_layernorm = RMSN(config.hidden_size) + def forward( self, hidden_states: Tensor, diff --git a/src/slicegpt/adapters/phi2_adapter.py b/src/slicegpt/adapters/phi2_adapter.py index 37cbf6fe..9c856ef0 100644 --- a/src/slicegpt/adapters/phi2_adapter.py +++ b/src/slicegpt/adapters/phi2_adapter.py @@ -16,6 +16,7 @@ from transformers.models.phi.modeling_phi import PhiConfig, PhiDecoderLayer, PhiForCausalLM from slicegpt.model_adapter import LayerAdapter, ModelAdapter +from slicegpt.modules import RMSN class CompressedPhiDecoderLayer(PhiDecoderLayer): @@ -25,6 +26,11 @@ class CompressedPhiDecoderLayer(PhiDecoderLayer): but with the addition of a shortcut_Q attribute. This attribute is used to rotate the residual tensors. """ + def __init__(self, config: PhiConfig, layer_idx: int, replace_layernorm: bool = False): + super().__init__(config, layer_idx) + if replace_layernorm: + self.input_layernorm = RMSN(config.hidden_size) + def forward( self, hidden_states: Tensor, diff --git a/src/slicegpt/adapters/sliced_llama.py b/src/slicegpt/adapters/sliced_llama.py new file mode 100644 index 00000000..ce575111 --- /dev/null +++ b/src/slicegpt/adapters/sliced_llama.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from transformers.configuration_utils import PretrainedConfig +from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel + +from slicegpt.adapters.llama_adapter import CompressedLlamaDecoderLayer, LlamaModelAdapter +from slicegpt.modules import RMSN +from slicegpt.rotate import slice_rotated_model +from slicegpt.slicing_scheduler import SlicingScheduler + + +class SlicedLlamaConfig(LlamaConfig): + model_type = "sliced_llama" + is_composition = True + + def __init__(self, **kwargs) -> None: + self.sparsity = kwargs.pop("sparsity", None) + self.new_hidden_size = kwargs.pop("new_hidden_size", None) + super().__init__(**kwargs) + + @classmethod + def from_pretrained(cls, config_path: str, sparsity: float, new_hidden_size: int) -> PretrainedConfig: + kwargs = {"sparsity": sparsity, "new_hidden_size": new_hidden_size} + return super().from_pretrained(config_path, **kwargs) + + +class SlicedLlama(LlamaModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.layers = nn.ModuleList( + [ + CompressedLlamaDecoderLayer(config, layer_idx, replace_layernorm=True) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.final_layernorm = RMSN(config.hidden_size) + + +class SlicedLlamaForCausalLM(LlamaForCausalLM): + def __init__( + self, + config, + scheduler: SlicingScheduler | None = None, + *model_args, + **kwargs, + ): + super().__init__(config) + self.model = SlicedLlama(config) + self.model_adapter = LlamaModelAdapter(self) + + if scheduler: + self.update_dims(scheduler) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + scheduler: SlicingScheduler | None, + sparsity: float, + new_hidden_size: int, + config_path: str, + *model_args, + **kwargs, + ): + """Overrides the from_pretrained method to accept the scheduler and returns the sliced model""" + config = SlicedLlamaConfig.from_pretrained(config_path, sparsity, new_hidden_size) + kwargs = {"scheduler": scheduler} + model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs) + model.load_state_dict(model.state_dict()) + return model + + def update_dims(self, scheduler: 
diff --git a/src/slicegpt/adapters/sliced_llama.py b/src/slicegpt/adapters/sliced_llama.py
new file mode 100644
index 00000000..ce575111
--- /dev/null
+++ b/src/slicegpt/adapters/sliced_llama.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel
+
+from slicegpt.adapters.llama_adapter import CompressedLlamaDecoderLayer, LlamaModelAdapter
+from slicegpt.modules import RMSN
+from slicegpt.rotate import slice_rotated_model
+from slicegpt.slicing_scheduler import SlicingScheduler
+
+
+class SlicedLlamaConfig(LlamaConfig):
+    model_type = "sliced_llama"
+    is_composition = True
+
+    def __init__(self, **kwargs) -> None:
+        self.sparsity = kwargs.pop("sparsity", None)
+        self.new_hidden_size = kwargs.pop("new_hidden_size", None)
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_pretrained(cls, config_path: str, sparsity: float, new_hidden_size: int) -> PretrainedConfig:
+        kwargs = {"sparsity": sparsity, "new_hidden_size": new_hidden_size}
+        return super().from_pretrained(config_path, **kwargs)
+
+
+class SlicedLlama(LlamaModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.layers = nn.ModuleList(
+            [
+                CompressedLlamaDecoderLayer(config, layer_idx, replace_layernorm=True)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.final_layernorm = RMSN(config.hidden_size)
+
+
+class SlicedLlamaForCausalLM(LlamaForCausalLM):
+    def __init__(
+        self,
+        config,
+        scheduler: SlicingScheduler | None = None,
+        *model_args,
+        **kwargs,
+    ):
+        super().__init__(config)
+        self.model = SlicedLlama(config)
+        self.model_adapter = LlamaModelAdapter(self)
+
+        if scheduler:
+            self.update_dims(scheduler)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path,
+        scheduler: SlicingScheduler | None,
+        sparsity: float,
+        new_hidden_size: int,
+        config_path: str,
+        *model_args,
+        **kwargs,
+    ):
+        """Overrides the from_pretrained method to accept the scheduler and returns the sliced model"""
+        config = SlicedLlamaConfig.from_pretrained(config_path, sparsity, new_hidden_size)
+        kwargs = {"scheduler": scheduler}
+        model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)
+        model.load_state_dict(model.state_dict())
+        return model
+
+    def update_dims(self, scheduler: SlicingScheduler) -> None:
+        layers = self.model_adapter.get_layers()
+
+        hidden_size = self.model_adapter.hidden_size
+        for layer_adapter in layers:
+            if not self.model_adapter.parallel_blocks:
+                layer_adapter.layer.mlp_shortcut_Q = torch.nn.Parameter(
+                    torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous()
+                )
+            layer_adapter.layer.attn_shortcut_Q = torch.nn.Parameter(
+                torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous()
+            )
+
+        slice_rotated_model(self.model_adapter, scheduler)
diff --git a/src/slicegpt/adapters/sliced_phi.py b/src/slicegpt/adapters/sliced_phi.py
new file mode 100644
index 00000000..87aee654
--- /dev/null
+++ b/src/slicegpt/adapters/sliced_phi.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.phi.modeling_phi import PhiConfig, PhiForCausalLM, PhiModel
+
+from slicegpt.adapters.phi2_adapter import CompressedPhiDecoderLayer, Phi2ModelAdapter
+from slicegpt.modules import RMSN
+from slicegpt.rotate import slice_rotated_model
+from slicegpt.slicing_scheduler import SlicingScheduler
+
+
+class SlicedPhi2Config(PhiConfig):
+    model_type = "sliced_phi2"
+    is_composition = True
+
+    def __init__(self, **kwargs) -> None:
+        self.sparsity = kwargs.pop("sparsity", None)
+        self.new_hidden_size = kwargs.pop("new_hidden_size", None)
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_pretrained(cls, config_path: str, sparsity: float, new_hidden_size: int) -> PretrainedConfig:
+        kwargs = {"sparsity": sparsity, "new_hidden_size": new_hidden_size}
+        return super().from_pretrained(config_path, local_files_only=True, **kwargs)
+
+
+class SlicedPhi(PhiModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.layers = nn.ModuleList(
+            [
+                CompressedPhiDecoderLayer(config, layer_idx, replace_layernorm=True)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.final_layernorm = RMSN(config.hidden_size)
+
+
+class SlicedPhiForCausalLM(PhiForCausalLM):
+    def __init__(
+        self,
+        config,
+        scheduler: SlicingScheduler | None = None,
+        *model_args,
+        **kwargs,
+    ):
+        super().__init__(config)
+        self.model = SlicedPhi(config)
+        self.model_adapter = Phi2ModelAdapter(self)
+
+        if scheduler:
+            self.update_dims(scheduler)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path,
+        scheduler: SlicingScheduler | None,
+        sparsity: float,
+        new_hidden_size: int,
+        config_path: str,
+        *model_args,
+        **kwargs,
+    ):
+        """Overrides the from_pretrained method to accept the scheduler and returns the sliced model"""
+        config = SlicedPhi2Config.from_pretrained(config_path, sparsity, new_hidden_size)
+        kwargs = {"scheduler": scheduler}
+        model = super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)
+        model.load_state_dict(model.state_dict())
+        return model
+
+    def update_dims(self, scheduler: SlicingScheduler) -> None:
+        layers = self.model_adapter.get_layers()
+
+        hidden_size = self.model_adapter.hidden_size
+        for layer_adapter in layers:
+            if not self.model_adapter.parallel_blocks:
+                layer_adapter.layer.mlp_shortcut_Q = torch.nn.Parameter(
+                    torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous()
+                )
+            layer_adapter.layer.attn_shortcut_Q = torch.nn.Parameter(
+                torch.zeros(hidden_size, hidden_size).to(dtype=torch.float16).contiguous()
+            )
+
+        slice_rotated_model(self.model_adapter, scheduler)
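A sketch of how these classes are meant to be driven, under the same assumptions the tests and hf_utils below make: a base PhiConfig previously saved to an illustrative "phi_config" directory, and a sliced phi-2 state dict on disk at an illustrative path.

import torch

from slicegpt.adapters.sliced_phi import SlicedPhi2Config, SlicedPhiForCausalLM
from slicegpt.slicing_scheduler import ConstSlicingScheduler

new_hidden_size = 2304  # phi-2's hidden size of 2560 sliced at 10% sparsity

# "phi_config" is assumed to hold the base PhiConfig, saved beforehand the way
# hf_utils.save_sliced_model does it (PhiConfig.from_pretrained(...).save_pretrained("phi_config")).
config = SlicedPhi2Config.from_pretrained(
    config_path="phi_config", sparsity=0.1, new_hidden_size=new_hidden_size
)

scheduler = ConstSlicingScheduler(new_hidden_size)
scheduler.setup(hidden_size=config.hidden_size, layers_num=config.num_hidden_layers, parallel_blocks=True)

# Passing the scheduler makes __init__ call update_dims(): the shortcut_Q parameters are
# allocated and slice_rotated_model() shrinks every weight to its sliced shape, so a
# state dict produced by rotate_and_slice (path illustrative) loads with strict=True.
sliced = SlicedPhiForCausalLM(config, scheduler).to(torch.float16)
sliced.load_state_dict(torch.load("sliced_phi2_state.pt"), strict=True, assign=True)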
diff --git a/src/slicegpt/hf_utils.py b/src/slicegpt/hf_utils.py
index 7a9167b5..a4ce3d05 100755
--- a/src/slicegpt/hf_utils.py
+++ b/src/slicegpt/hf_utils.py
@@ -7,10 +7,15 @@
 import torch
 from peft import LoraConfig, get_peft_model
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers.models.llama.modeling_llama import LlamaConfig
+from transformers.models.phi.modeling_phi import PhiConfig
 
+from .adapters.sliced_llama import SlicedLlamaConfig, SlicedLlamaForCausalLM
+from .adapters.sliced_phi import SlicedPhi2Config, SlicedPhiForCausalLM
 from .layernorm_fusion import fuse_modules, replace_layers
 from .model_adapter import ModelAdapter, SlicingConfig
 from .rotate import slice_rotated_model
+from .slicing_scheduler import ConstSlicingScheduler, SlicingScheduler
 
 
 def do_not_initialize(func):
@@ -102,7 +107,7 @@ def get_model_and_tokenizer(
     model.eval()  # This switches off dropout.
     model_adapter.use_cache = False
 
-    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, token=token, local_files_only=local_model)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, token=token, local_files_only=local_model)
 
     model_adapter.post_init(tokenizer)
     logging.info("Loading model done")
@@ -110,24 +115,20 @@ def get_model_and_tokenizer(
     return model_adapter, tokenizer
 
 
-@do_not_initialize
 def load_sliced_model(
     model_name: str,
     sliced_model_path: str,
     *,
     token: str | None = None,
-    lora_config: LoraConfig = None,
+    lora_config: LoraConfig | None = None,
     sparsity: float | None = None,
     round_interval: int | None = 1,
-) -> tuple[ModelAdapter, PreTrainedTokenizerBase]:
+) -> tuple[ModelAdapter | torch.nn.Module, PreTrainedTokenizerBase]:
     """
     Load the sliced model and the tokenizer from the given path. If lora_config is supplied as an arg
     then this function will return a PEFT model (post-slicing finetuned model).
     The corresponding model adapter class must be imported before calling this method.
""" - my_model_suffix = pathlib.Path(model_name).name - my_sliced_model_name = f"{my_model_suffix}_{sparsity}.pt" - my_sliced_model_config = f"{my_model_suffix}_{sparsity}.json" model_adapter, tokenizer = get_model_and_tokenizer( model_name, @@ -135,6 +136,36 @@ def load_sliced_model( uninitialized=True, token=token, ) + + # handle loading sliced HF compatible models + if model_name.startswith("microsoft") or model_name.startswith("llama"): + new_embedding_dimension = int((1 - sparsity) * model_adapter.hidden_size) + new_embedding_dimension -= new_embedding_dimension % round_interval + + scheduler = ConstSlicingScheduler(new_embedding_dimension) + + layers = model_adapter.get_layers() + scheduler.setup( + hidden_size=model_adapter.hidden_size, + layers_num=len(layers), + parallel_blocks=True, + ) + + sliced_model = SlicedPhiForCausalLM.from_pretrained( + sliced_model_path, + scheduler=scheduler, + config_path=sliced_model_path, + sparsity=sparsity, + new_hidden_size=new_embedding_dimension, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, token=token, local_files_only=True) + return sliced_model, tokenizer + + my_model_suffix = pathlib.Path(model_name).name + my_sliced_model_name = f"{my_model_suffix}_{sparsity}.pt" + my_sliced_model_config = f"{my_model_suffix}_{sparsity}.json" + replace_layers(model_adapter) fuse_modules(model_adapter) @@ -173,3 +204,51 @@ def load_sliced_model( model_adapter.model.eval() return model_adapter, tokenizer + + +def save_sliced_model( + model_name: str, + dtype: torch.dtype, + model: torch.nn.Module, + scheduler: SlicingScheduler, + save_sliced_model_dir: str | pathlib.Path, + sparsity: float, + new_hidden_size: int, + slicing_conf: SlicingConfig, +): + if model_name == "microsoft/phi-2": + config = PhiConfig.from_pretrained( + model_name, + torch_dtype=torch.float16, + ) + + config.save_pretrained("phi_config") + config_to_save = SlicedPhi2Config.from_pretrained( + config_path="phi_config", sparsity=sparsity, new_hidden_size=new_hidden_size + ) + + sliced_model = SlicedPhiForCausalLM(config_to_save, scheduler).to(dtype) + sliced_model.load_state_dict(model.state_dict(), strict=True, assign=True) + sliced_model.save_pretrained(save_sliced_model_dir) + + elif "meta-llama" in model_name: + config = LlamaConfig.from_pretrained( + model_name, + torch_dtype=torch.float16, + ) + config.save_pretrained("llama_config") + config_to_save = SlicedLlamaConfig.from_pretrained( + config_path="llama_config", + sparsity=sparsity, + new_hidden_size=new_hidden_size, + ) + + sliced_model = SlicedLlamaForCausalLM(config_to_save, scheduler).to(dtype) + sliced_model.load_state_dict(model.state_dict(), strict=True, assign=True) + sliced_model.save_pretrained(save_sliced_model_dir) + else: + # Save the sliced model for other models types + sliced_model_name = save_sliced_model_dir / f'{pathlib.Path(model_name).name}_{sparsity}.pt' + torch.save(model.state_dict(), sliced_model_name) + config_path = sliced_model_name.with_suffix('.json') + config_path.write_text(slicing_conf.to_json_string()) diff --git a/src/slicegpt/model_adapter.py b/src/slicegpt/model_adapter.py index f9b0f8bc..6c4052e6 100644 --- a/src/slicegpt/model_adapter.py +++ b/src/slicegpt/model_adapter.py @@ -113,7 +113,7 @@ class ModelAdapter(ABC): To implement a new model adapter, implement the interface defined in this class """ - def __init__(self): + def __init__(self) -> None: self.slicing_conf: SlicingConfig | None = None @property diff --git a/src/slicegpt/rotate.py 
diff --git a/src/slicegpt/rotate.py b/src/slicegpt/rotate.py
index c5333b3f..5d31953e 100644
--- a/src/slicegpt/rotate.py
+++ b/src/slicegpt/rotate.py
@@ -26,7 +26,7 @@ def rotate_attention_inputs(layer_adapter: LayerAdapter, Q: torch.Tensor) -> None:
 def slice_attention_inputs(layer_adapter: LayerAdapter, new_embedding_dimension: int) -> None:
     # Slice the WQ, WK and WV matrices of the self-attention layer.
     for W in layer_adapter.get_attention_inputs():
-        W.weight.data = W.weight.data[:, :new_embedding_dimension]
+        W.weight.data = W.weight.data[:, :new_embedding_dimension].contiguous()
         W.in_features = new_embedding_dimension
 
     layer_adapter.layer.attn_shortcut_Q = nn.Parameter(layer_adapter.layer.attn_shortcut_Q[:new_embedding_dimension, :])
@@ -47,7 +47,7 @@ def rotate_attention_output(layer_adapter: LayerAdapter, Q: torch.Tensor) -> None:
 def slice_attention_output(layer_adapter: LayerAdapter, new_embedding_dimension: int) -> None:
     # Slice output matrix of the self-attention layer.
     W = layer_adapter.get_attention_output()
-    W.weight.data = W.weight.data[:new_embedding_dimension, :]
+    W.weight.data = W.weight.data[:new_embedding_dimension, :].contiguous()
     if W.bias is not None:
         W.bias.data = W.bias.data[:new_embedding_dimension]
     W.out_features = new_embedding_dimension
@@ -64,7 +64,7 @@ def rotate_mlp_input(layer_adapter: LayerAdapter, Q: torch.Tensor) -> None:
 def slice_mlp_input(layer_adapter: LayerAdapter, new_embedding_dimension: int) -> None:
     # Slice the MLP input weights.
     for W in layer_adapter.get_mlp_inputs():
-        W.weight.data = W.weight.data[:, :new_embedding_dimension]
+        W.weight.data = W.weight.data[:, :new_embedding_dimension].contiguous()
         W.in_features = new_embedding_dimension
 
 
@@ -82,7 +82,7 @@ def rotate_mlp_output(layer_adapter: LayerAdapter, Q: torch.Tensor) -> None:
 def slice_mlp_output(layer_adapter: LayerAdapter, new_embedding_dimension: int) -> None:
     # Slice the MLP output weights and bias.
     W = layer_adapter.get_mlp_output()
-    W.weight.data = W.weight.data[:new_embedding_dimension, :]
+    W.weight.data = W.weight.data[:new_embedding_dimension, :].contiguous()
     if W.bias is not None:
         W.bias.data = W.bias.data[:new_embedding_dimension]
     W.out_features = new_embedding_dimension
@@ -102,7 +102,7 @@ def rotate_embeddings(model_adapter: ModelAdapter, Q: torch.Tensor) -> None:
 def slice_embeddings(model_adapter: ModelAdapter, new_embedding_dimensions: dict[int, int]) -> None:
     # Slice the embeddings.
     for i, W in enumerate(model_adapter.get_embeddings()):
-        W.weight.data = W.weight.data[:, : new_embedding_dimensions[i]]
+        W.weight.data = W.weight.data[:, : new_embedding_dimensions[i]].contiguous()
         W.embedding_dim = new_embedding_dimensions[i]
 
 
@@ -117,7 +117,7 @@ def rotate_head(model_adapter: ModelAdapter, Q: torch.Tensor) -> None:
 def slice_head(model_adapter: ModelAdapter, new_embedding_dimension: int) -> None:
     # Slice the head.
     lm_head = model_adapter.get_lm_head()
-    lm_head.weight.data = lm_head.weight.data[:, :new_embedding_dimension]
+    lm_head.weight.data = lm_head.weight.data[:, :new_embedding_dimension].contiguous()
     lm_head.in_features = new_embedding_dimension
 
 
@@ -163,7 +163,11 @@ def rotate_and_slice_sequential(
             ignore_masks.append(batch["attention_mask"])
 
     layers = model_adapter.get_layers()
-    slicing_scheduler.setup(hidden_size=model_adapter.hidden_size, layers_num=len(layers), parallel_blocks=False)
+    slicing_scheduler.setup(
+        hidden_size=model_adapter.hidden_size,
+        layers_num=len(layers),
+        parallel_blocks=False,
+    )
 
     # rotate and slice embeddings
     eig_val, Q = pca_calc(inps, ignore_masks)
@@ -277,7 +281,11 @@ def rotate_and_slice_parallel(
             ignore_masks.append(batch["attention_mask"])
 
     layers = model_adapter.get_layers()
-    slicing_scheduler.setup(hidden_size=model_adapter.hidden_size, layers_num=len(layers), parallel_blocks=True)
+    slicing_scheduler.setup(
+        hidden_size=model_adapter.hidden_size,
+        layers_num=len(layers),
+        parallel_blocks=True,
+    )
 
     # rotate and slice embeddings
     _, Q = pca_calc(inps, ignore_masks)
@@ -465,17 +473,21 @@ def slice_rotated_model(model_adapter: ModelAdapter, slicing_scheduler: SlicingScheduler) -> None:
 
         if model_adapter.parallel_blocks:  # parallel case
             layer.attn_shortcut_Q = nn.Parameter(
-                layer.attn_shortcut_Q[:, : slicing_scheduler.get_attention_output_dimension(i, match_head_dim=True)]
+                layer.attn_shortcut_Q[
+                    :, : slicing_scheduler.get_attention_output_dimension(i, match_head_dim=True)
+                ].contiguous()
             )
             slice_attention_output(
                 layer_adapter, slicing_scheduler.get_attention_output_dimension(i, match_head_dim=True)
             )
         else:  # sequential case
             layer.attn_shortcut_Q = nn.Parameter(
-                layer.attn_shortcut_Q[:, : slicing_scheduler.get_attention_output_dimension(i, match_head_dim=False)]
+                layer.attn_shortcut_Q[
+                    :, : slicing_scheduler.get_attention_output_dimension(i, match_head_dim=False)
+                ].contiguous()
            )
             layer.mlp_shortcut_Q = nn.Parameter(
-                layer.mlp_shortcut_Q[:, : slicing_scheduler.get_mlp_output_dimension(i)]
+                layer.mlp_shortcut_Q[:, : slicing_scheduler.get_mlp_output_dimension(i)].contiguous()
             )
 
         # slice attention weights 1st dimension
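The `.contiguous()` calls added throughout this file are what let the new HF-format saving path work: slicing a weight along its last dimension leaves a non-contiguous view, which the safetensors serialization behind `save_pretrained` cannot write. A minimal illustration:

import torch

W = torch.randn(8, 8)
print(W[:, :4].is_contiguous())               # False: a column slice is a strided view
print(W[:4, :].is_contiguous())               # True: a row slice keeps its storage dense
print(W[:, :4].contiguous().is_contiguous())  # True: .contiguous() copies into dense storage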
diff --git a/tests/test_model_adapter.py b/tests/test_model_adapter.py
index c4379e4a..645a7997 100644
--- a/tests/test_model_adapter.py
+++ b/tests/test_model_adapter.py
@@ -116,7 +116,6 @@ def create_adapter(self) -> LlamaModelAdapter:
         config = LlamaConfig(
             vocab_size=32,
             hidden_size=8,
-            intermediate_size=32,
             num_hidden_layers=2,
             num_attention_heads=2,
             max_position_embeddings=16,
@@ -128,8 +127,6 @@
 class TestPhi2Adapter(ModelAdapterTestBase):
     def create_adapter(self) -> Phi2ModelAdapter:
         # a tiny phi, just to test adapter.
-        config = PhiConfig(
-            vocab_size=32, hidden_size=8, intermediate_size=32, num_hidden_layers=2, num_attention_heads=2
-        )
+        config = PhiConfig(vocab_size=32, hidden_size=8, num_hidden_layers=2, num_attention_heads=2)
         model = PhiForCausalLM(config)
         return Phi2ModelAdapter(model)
diff --git a/tests/test_slicing.py b/tests/test_slicing.py
index 79830c27..d108469b 100644
--- a/tests/test_slicing.py
+++ b/tests/test_slicing.py
@@ -1,8 +1,14 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-from slicegpt import hf_utils, layernorm_fusion
+import pytest
+import torch
+from transformers.models.phi.modeling_phi import PhiConfig
+
+from slicegpt import data_utils, gpu_utils, hf_utils, layernorm_fusion, rotate
 from slicegpt.adapters.opt_adapter import OPTModelAdapter
+from slicegpt.adapters.sliced_phi import SlicedPhi2Config, SlicedPhiForCausalLM
+from slicegpt.slicing_scheduler import ConstSlicingScheduler
 
 
 def test_layernorm_fusion_replaces_modules() -> None:
@@ -21,3 +27,106 @@ def test_layernorm_fusion_replaces_modules() -> None:
 
 def get_module_names(model) -> list[str]:
     return [name for name, _ in model.named_parameters()]
+
+
+@pytest.mark.experiment
+@pytest.mark.gpu
+def test_HF_model():
+    """Check that the HF model weights are equivalent to the sliced model weights"""
+    model_name = "microsoft/phi-2"
+    model_adapter, tokenizer = hf_utils.get_model_and_tokenizer(model_name)
+    sparsity = 0.1
+    new_hidden_size = 2304
+
+    layernorm_fusion.replace_layers(model_adapter)
+    layernorm_fusion.fuse_modules(model_adapter)
+
+    phi_config = PhiConfig.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+    )
+
+    phi_config.save_pretrained("phi_config")
+    config = SlicedPhi2Config.from_pretrained(
+        config_path="phi_config", sparsity=sparsity, new_hidden_size=new_hidden_size
+    )
+
+    sliced_model = SlicedPhiForCausalLM(config).to(torch.float16)
+    sliced_model.load_state_dict(model_adapter.model.state_dict(), strict=True, assign=True)
+
+    # The sliced model weights should be identical to the HF model weights after layer norm fusion
+    assert compare_weights(model_adapter.model, sliced_model.model)
+
+    dataset = data_utils.get_dataset("wikitext2")
+    train_dataset, test_dataset = dataset["train"], dataset["test"]
+
+    train_loader = data_utils.prepare_dataloader(dataset=train_dataset, tokenizer=tokenizer)
+
+    test_loader = data_utils.prepare_test_dataloader(dataset=test_dataset, tokenizer=tokenizer)
+
+    scheduler = ConstSlicingScheduler(new_hidden_size)
+    rotate.rotate_and_slice(model_adapter, train_loader, scheduler, final_orientation="random")
+
+    sliced_ppl = gpu_utils.evaluate_ppl(model_adapter.model.to("cuda"), tokenizer.pad_token_id, test_loader)
+
+    sliced_model = SlicedPhiForCausalLM(config, scheduler).to(torch.float16)
+    sliced_model = sliced_model.to(torch.float16)
+    sliced_model.load_state_dict(model_adapter.model.state_dict(), strict=True, assign=True)
+    sliced_model.save_pretrained("sliced_phi2_model")
+
+    new_model_ppl = gpu_utils.evaluate_ppl(sliced_model.to("cuda"), tokenizer.pad_token_id, test_loader)
+
+    # The perplexity of the sliced model should be the same as the HF model
+    assert sliced_ppl == new_model_ppl
+
+    # load the sliced model back
+    sliced_model = SlicedPhiForCausalLM.from_pretrained(
+        "sliced_phi2_model",
+        scheduler=scheduler,
+        config_path="sliced_phi2_model",
+        sparsity=sparsity,
+        new_hidden_size=new_hidden_size,
+    )
+    sliced_model = sliced_model.to(torch.float16)
+
+    assert sliced_model is not None
+    assert isinstance(sliced_model, SlicedPhiForCausalLM)
+    assert sliced_model.config.sparsity == sparsity
+    assert sliced_model.config.new_hidden_size == new_hidden_size
+
+    loaded_model_ppl = gpu_utils.evaluate_ppl(sliced_model.to("cuda"), tokenizer.pad_token_id, test_loader)
+    assert loaded_model_ppl == new_model_ppl
+
+
+def test_save_and_load_HF_model():
+    """Test HF model saving and loading"""
+    sparsity = 0.0
+    new_hidden_size = 2506
+    config_name = "sliced_model_config"
+    model_name = "sliced_model"
+
+    kwargs = {"sparsity": sparsity, "new_hidden_size": new_hidden_size}
"new_hidden_size": new_hidden_size} + + config = SlicedPhi2Config(**kwargs) + config.save_pretrained(config_name) + + config = SlicedPhi2Config.from_pretrained(config_name, sparsity, new_hidden_size) + + sliced_model = SlicedPhiForCausalLM(config).to(torch.float16) + sliced_model.save_pretrained(model_name) + + scheduler = ConstSlicingScheduler(new_hidden_size) + sliced_model = SlicedPhiForCausalLM.from_pretrained( + model_name, scheduler=scheduler, config_path=model_name, sparsity=sparsity, new_hidden_size=new_hidden_size + ) + + assert isinstance(sliced_model, SlicedPhiForCausalLM) + assert sliced_model.config.sparsity == sparsity + assert sliced_model.config.new_hidden_size == new_hidden_size + + +def compare_weights(model1, model2): + for p1, p2 in zip(model1.parameters(), model2.parameters()): + if not torch.equal(p1.data, p2.data): + return False + return True \ No newline at end of file