diff --git a/src/slicegpt/model_adapter.py b/src/slicegpt/model_adapter.py
index f9b0f8bc..28315d67 100644
--- a/src/slicegpt/model_adapter.py
+++ b/src/slicegpt/model_adapter.py
@@ -438,21 +438,22 @@ class SlicingConfig:
     do_slice_head: bool = False
     parallel_blocks: bool = False
 
-    # use dict[int, int] instead of list[int] to allow for arbitrary order updates and default dicts
+    # dimensions used in both the sequential and parallel blocks cases
     embedding_dimensions: dict[int, int] = field(default_factory=dict)
-
     attention_input_dimensions: dict[int, int] = field(default_factory=dict)
-    attention_output_dimensions: dict[int, int] = field(default_factory=dict)
+    mlp_output_dimensions: dict[int, int] = field(default_factory=dict)
 
+    # dimensions of the second path, used only in the sequential blocks case
+    attention_output_dimensions: dict[int, int] = field(default_factory=dict)
     mlp_input_dimensions: dict[int, int] = field(default_factory=dict)
-    mlp_output_dimensions: dict[int, int] = field(default_factory=dict)
 
     head_dimension: int | None = None
-    const_dimension: int | None = None  # to be able to load models without config, sliced with const sparsity
+    # used when loading models sliced with const sparsity that are missing a JSON config
+    const_dimension: int | None = None
 
     @staticmethod
-    def from_dict(d: dict) -> 'SlicingConfig':
+    def from_dict(d: dict) -> SlicingConfig:
         """Return a SliceConfig object constructed from the provided dictionary."""
 
         def convert_dict_keys_to_int(d: Any) -> Any:
@@ -470,7 +471,7 @@ def convert_dict_keys_to_int(d: Any) -> Any:
         return SlicingConfig(**convert_dict_keys_to_int(d))
 
     @staticmethod
-    def from_json_string(json_str: str) -> 'SlicingConfig':
+    def from_json_string(json_str: str) -> SlicingConfig:
         """Return a SliceConfig object constructed from the provided JSON string."""
         return SlicingConfig.from_dict(json.loads(json_str))
 
@@ -485,6 +486,6 @@ def to_json_string(self) -> str:
         """Return a JSON representation of this object."""
         return json.dumps(self.to_dict())
 
-    def clone(self) -> 'SlicingConfig':
+    def clone(self) -> SlicingConfig:
         """Return a clone of this object."""
         return copy.deepcopy(self)
diff --git a/src/slicegpt/rotate.py b/src/slicegpt/rotate.py
index d201e323..ee68f92c 100644
--- a/src/slicegpt/rotate.py
+++ b/src/slicegpt/rotate.py
@@ -166,6 +166,7 @@ def rotate_and_slice_sequential(
 
     # rotate and slice embeddings
     eig_val, Q = pca_calc(inps, ignore_masks)
+    slicing_scheduler.set_embedding_eigenvalues(eig_val.detach().cpu().tolist())
     Q = Q.to(device=config.device)
     if final_orientation == 'random':
         R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_embedding_dimensions()[0])
@@ -193,6 +194,7 @@ def rotate_and_slice_sequential(
 
         mlp_ln_inputs, _ = get_signals(layer_adapter, args, kwargs)
         eig_val, Q = pca_calc(mlp_ln_inputs, ignore_masks)
+        slicing_scheduler.set_mlp_eigenvalues(idx, eig_val.detach().cpu().tolist())
         Q = Q.to(device=config.device, dtype=torch.float64)
         if final_orientation == 'random':
             R = random_orthogonal_upper_left(
@@ -224,6 +226,7 @@ def rotate_and_slice_sequential(
         # with slicing between Attention and mlp.
         _, inps = get_signals(layer_adapter, args, kwargs)
         eig_val, Q = pca_calc(inps, ignore_masks)
+        slicing_scheduler.set_attention_eigenvalues(idx, eig_val.detach().cpu().tolist())
         if final_orientation == 'random':
             R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_mlp_output_dimension(idx))
             Q = Q @ R.to(Q.device)
@@ -279,7 +282,8 @@ def rotate_and_slice_parallel(
     slicing_scheduler.setup(hidden_size=model_adapter.hidden_size, layers_num=len(layers), parallel_blocks=True)
 
     # rotate and slice embeddings
-    _, Q = pca_calc(inps, ignore_masks)
+    eig_val, Q = pca_calc(inps, ignore_masks)
+    slicing_scheduler.set_embedding_eigenvalues(eig_val.detach().cpu().tolist())
     Q = Q.to(device=config.device)
     if final_orientation == 'random':
         R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_embedding_dimensions()[0])
@@ -322,7 +326,8 @@ def rotate_and_slice_parallel(
             outputs.append(out)
 
         inps = outputs
-        _, Q = pca_calc(inps, ignore_masks)
+        eig_val, Q = pca_calc(inps, ignore_masks)
+        slicing_scheduler.set_attention_eigenvalues(idx, eig_val.detach().cpu().tolist())
 
         if final_orientation == 'random':
             R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_mlp_output_dimension(idx))
diff --git a/src/slicegpt/slicing_scheduler.py b/src/slicegpt/slicing_scheduler.py
index e06b78f0..fe551673 100644
--- a/src/slicegpt/slicing_scheduler.py
+++ b/src/slicegpt/slicing_scheduler.py
@@ -1,7 +1,14 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from typing import Callable, final
 
+import numpy as np
+
 from slicegpt.model_adapter import SlicingConfig
 
 
@@ -20,6 +27,11 @@ def __init__(self, *, do_slice_head: bool = False):
         self.slicing_conf: SlicingConfig = SlicingConfig()
         self.slicing_conf.do_slice_head = do_slice_head
 
+        # eigenvalues obtained from PCA
+        self.embedding_eigenvalues: list[float] = []
+        self.attention_eigenvalues: dict[int, list[float]] = {}
+        self.mlp_eigenvalues: dict[int, list[float]] = {}
+
     @property
     def do_slice_head(self) -> bool:
         """Return whether to slice the head."""
@@ -49,17 +61,23 @@ def setup(self, *, hidden_size: int, layers_num: int, parallel_blocks: bool) ->
     @final
     def get_embedding_dimensions(self) -> dict[int, int]:
         """Return the input embedding dimensions."""
-        val = self._get_input_embedding_dimensions()
+        if self.slicing_conf.embedding_dimensions:
+            return self.slicing_conf.embedding_dimensions
+
+        val = self._get_embedding_dimensions()
         self.slicing_conf.embedding_dimensions = val
         return val
 
     @abstractmethod
-    def _get_input_embedding_dimensions(self) -> dict[int, int]:
+    def _get_embedding_dimensions(self) -> dict[int, int]:
         raise NotImplementedError
 
     @final
     def get_attention_input_dimension(self, idx: int) -> int:
         """Return the attention input dimension for the specified layer index."""
+        if idx in self.slicing_conf.attention_input_dimensions:
+            return self.slicing_conf.attention_input_dimensions[idx]
+
         val = self._get_attention_input_dimension(idx)
         self.slicing_conf.attention_input_dimensions[idx] = val
         return val
@@ -69,12 +87,30 @@ def _get_attention_input_dimension(self, idx: int) -> int:
         raise NotImplementedError
 
     @final
-    def get_attention_output_dimension(self, idx, match_head_dim: bool) -> int:
+    def get_mlp_output_dimension(self, idx: int) -> int:
+        """Return the mlp output dimension for the specified layer index."""
+        if idx in self.slicing_conf.mlp_output_dimensions:
+            return self.slicing_conf.mlp_output_dimensions[idx]
+
+        use_head_dim = idx == self.layers_num - 1
+        val = self._get_mlp_output_dimension(idx) if not use_head_dim else self.get_head_dimension()
+        self.slicing_conf.mlp_output_dimensions[idx] = val
+        return val
+
+    @abstractmethod
+    def _get_mlp_output_dimension(self, idx: int) -> int:
+        raise NotImplementedError
+
+    @final
+    def get_attention_output_dimension(self, idx, match_head_dim: bool | None = None) -> int:
         """Return the attention output dimension for the specified layer index."""
         if self.parallel_blocks:
             return self.get_mlp_output_dimension(idx)
 
-        use_head_dim = match_head_dim and idx == self.layers_num - 1
+        if idx in self.slicing_conf.attention_output_dimensions:
+            return self.slicing_conf.attention_output_dimensions[idx]
+
+        use_head_dim = idx == self.layers_num - 1 and match_head_dim
         val = self._get_attention_output_dimension(idx) if not use_head_dim else self.get_head_dimension()
         self.slicing_conf.attention_output_dimensions[idx] = val
         return val
@@ -89,6 +125,9 @@ def get_mlp_input_dimension(self, idx: int) -> int:
         if self.parallel_blocks:
             return self.get_attention_input_dimension(idx)
 
+        if idx in self.slicing_conf.mlp_input_dimensions:
+            return self.slicing_conf.mlp_input_dimensions[idx]
+
         val = self._get_mlp_input_dimension(idx)
         self.slicing_conf.mlp_input_dimensions[idx] = val
         return val
@@ -97,21 +136,12 @@ def get_mlp_input_dimension(self, idx: int) -> int:
     def _get_mlp_input_dimension(self, idx: int) -> int:
         raise NotImplementedError
 
-    @final
-    def get_mlp_output_dimension(self, idx: int) -> int:
-        """Return the mlp output dimension for the specified layer index."""
-        use_head_dim = idx == self.layers_num - 1
-        val = self._get_mlp_output_dimension(idx) if not use_head_dim else self.get_head_dimension()
-        self.slicing_conf.mlp_output_dimensions[idx] = val
-        return val
-
-    @abstractmethod
-    def _get_mlp_output_dimension(self, idx: int) -> int:
-        raise NotImplementedError
-
     @final
     def get_head_dimension(self) -> int:
         """Return the LM head dimension."""
+        if self.slicing_conf.head_dimension is not None:
+            return self.slicing_conf.head_dimension
+
         val = self._get_head_dimension() if self.slicing_conf.do_slice_head else self.hidden_size
         self.slicing_conf.head_dimension = val
         return val
@@ -120,6 +150,18 @@ def get_head_dimension(self) -> int:
     def _get_head_dimension(self) -> int:
         raise NotImplementedError
 
+    def set_embedding_eigenvalues(self, eigenvalues: list[float]) -> None:
+        """Set the eigenvalues of the embeddings PCA."""
+        self.embedding_eigenvalues = eigenvalues
+
+    def set_attention_eigenvalues(self, idx: int, eigenvalues: list[float]) -> None:
+        """Set the eigenvalues of the attention layer PCA."""
+        self.attention_eigenvalues[idx] = eigenvalues
+
+    def set_mlp_eigenvalues(self, idx: int, eigenvalues: list[float]) -> None:
+        """Set the eigenvalues of the MLP layer PCA."""
+        self.mlp_eigenvalues[idx] = eigenvalues
+
 
 class ConfigSlicingScheduler(SlicingScheduler):
     """Slicing scheduler that returns the dimensions specified in the config."""
@@ -128,21 +170,21 @@ def __init__(self, config: SlicingConfig):
         super().__init__()
         self.slicing_conf = config
 
-    def _get_input_embedding_dimensions(self) -> dict[int, int]:
+    def _get_embedding_dimensions(self) -> dict[int, int]:
         return self.slicing_conf.embedding_dimensions
 
     def _get_attention_input_dimension(self, idx: int) -> int:
         return self.slicing_conf.attention_input_dimensions[idx]
 
+    def _get_mlp_output_dimension(self, idx: int) -> int:
+        return self.slicing_conf.mlp_output_dimensions[idx]
+
     def _get_attention_output_dimension(self, idx: int) -> int:
         return self.slicing_conf.attention_output_dimensions[idx]
 
     def _get_mlp_input_dimension(self, idx: int) -> int:
         return self.slicing_conf.mlp_input_dimensions[idx]
 
-    def _get_mlp_output_dimension(self, idx: int) -> int:
-        return self.slicing_conf.mlp_output_dimensions[idx]
-
     def _get_head_dimension(self) -> int:
         return self.slicing_conf.head_dimension
 
@@ -154,19 +196,19 @@ def __init__(self, dimension: int, *, do_slice_head: bool = False):
         super().__init__(do_slice_head=do_slice_head)
         self.dimension: int = dimension
 
-    def _get_input_embedding_dimensions(self) -> dict[int, int]:
+    def _get_embedding_dimensions(self) -> dict[int, int]:
         return defaultdict(lambda: self.dimension)
 
     def _get_attention_input_dimension(self, idx: int) -> int:
         return self.dimension
 
-    def _get_attention_output_dimension(self, idx: int) -> int:
+    def _get_mlp_output_dimension(self, idx: int) -> int:
         return self.dimension
 
-    def _get_mlp_input_dimension(self, idx: int) -> int:
+    def _get_attention_output_dimension(self, idx: int) -> int:
         return self.dimension
 
-    def _get_mlp_output_dimension(self, idx: int) -> int:
+    def _get_mlp_input_dimension(self, idx: int) -> int:
         return self.dimension
 
     def _get_head_dimension(self) -> int:
@@ -186,13 +228,13 @@ def __init__(self, *, do_slice_head: bool = False):
     def _get_attention_input_dimension(self, idx: int) -> int:
         # return the input embedding dimension when at the first attn layer inputs
         if idx == 0:
-            return self._get_input_embedding_dimensions()[0]  # all dimensions are the same there
+            return self.get_embedding_dimensions()[0]  # all dimensions are the same there
 
-        return self._get_mlp_output_dimension(idx - 1)
+        return self.get_mlp_output_dimension(idx - 1)
 
     @final
     def _get_mlp_input_dimension(self, idx: int) -> int:
-        return self._get_attention_output_dimension(idx)
+        return self.get_attention_output_dimension(idx)
 
 
 class FunctionSlicingScheduler(ForwardSlicingScheduler):
@@ -222,7 +264,7 @@ def _get_layer_dimension(self, idx: int, is_attn_layer: bool = False) -> int:
         val -= val % self.round_interval
         return val
 
-    def _get_input_embedding_dimensions(self) -> dict[int, int]:
+    def _get_embedding_dimensions(self) -> dict[int, int]:
         return defaultdict(lambda: self._get_layer_dimension(0))
 
     def _get_attention_output_dimension(self, idx: int) -> int:
@@ -242,7 +284,7 @@ def create_linear(
         attn_end: float | None = None,
         round_interval: int = 1,
         do_slice_head: bool = False,
-    ) -> 'FunctionSlicingScheduler':
+    ) -> FunctionSlicingScheduler:
         """Create a linear slicing scheduler, mainly as an example for testing."""
 
         def linear(start: float, end: float) -> Callable[[float], float]:
@@ -259,3 +301,38 @@ def linear_sparsity_func(location: float) -> float:
             round_interval=round_interval,
             do_slice_head=do_slice_head,
         )
+
+
+class ExplainedVarianceSlicingScheduler(ForwardSlicingScheduler):
+    """A slicing scheduler that applies sparsity based on the explained variance from the PCA."""
+
+    def __init__(
+        self,
+        *,
+        uev_threshold: float,
+        round_interval: int = 1,
+        do_slice_head: bool = False,
+    ):
+        super().__init__(do_slice_head=do_slice_head)
+        self.uev_threshold: float = uev_threshold
+        self.round_interval: int = round_interval
+
+    def _get_layer_dimension(self, eigen_vals: list[float], plot: bool = False) -> int:
+        eigen_vals = np.array(eigen_vals)
+        cum_var = np.cumsum(eigen_vals) / np.sum(eigen_vals)
+        dim = np.argmax(cum_var > 1 - self.uev_threshold)
+        dim -= dim % self.round_interval
+        dim = int(dim)
+        return dim
+
+    def _get_embedding_dimensions(self) -> dict[int, int]:
+        return defaultdict(lambda: self._get_layer_dimension(self.embedding_eigenvalues))
+
+    def _get_attention_output_dimension(self, idx: int) -> int:
+        return self._get_layer_dimension(self.mlp_eigenvalues[idx])
+
+    def _get_mlp_output_dimension(self, idx: int) -> int:
+        return self._get_layer_dimension(self.attention_eigenvalues[idx])
+
+    def _get_head_dimension(self) -> int:
+        return self.get_attention_output_dimension(self.layers_num - 1)
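
Usage note (not part of the patch): the sketch below shows one way the new ExplainedVarianceSlicingScheduler could be driven, inferred only from the interfaces added above. The rotate-and-slice loop feeds PCA eigenvalues into the scheduler, which then derives per-layer slice widths from the unexplained-variance threshold. The hidden size, layer count, threshold, and the synthetic eigenvalue spectrum are illustrative assumptions, not values from the patch.

import numpy as np

from slicegpt.slicing_scheduler import ExplainedVarianceSlicingScheduler

# Hypothetical model shape, chosen only for illustration.
hidden_size, layers_num = 4096, 32

# Tolerate at most 5% unexplained variance per sliced signal, rounding widths to multiples of 8.
scheduler = ExplainedVarianceSlicingScheduler(uev_threshold=0.05, round_interval=8)
scheduler.setup(hidden_size=hidden_size, layers_num=layers_num, parallel_blocks=False)

# Stand-in for the eigenvalues that pca_calc produces in rotate.py; the real loop calls
# set_*_eigenvalues with eig_val.detach().cpu().tolist() for the embeddings and each layer.
fake_spectrum = np.exp(-np.arange(hidden_size) / 500.0).tolist()
scheduler.set_embedding_eigenvalues(fake_spectrum)
scheduler.set_attention_eigenvalues(0, fake_spectrum)

# Embedding slice width: the index where the cumulative explained variance first exceeds
# 1 - uev_threshold, rounded down to a multiple of round_interval.
print(scheduler.get_embedding_dimensions()[0])

# Layer 0's MLP output width is derived from the attention eigenvalues, mirroring how
# rotate_and_slice_sequential pairs each PCA with the dimension it slices.
print(scheduler.get_mlp_output_dimension(0))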