From 3d9246a98025ca835d1445c6e66d2ee3510f150c Mon Sep 17 00:00:00 2001 From: Pavel Myshkov Date: Mon, 19 Feb 2024 13:22:39 +0000 Subject: [PATCH 1/4] Add variance based scheduler --- src/slicegpt/model_adapter.py | 17 ++-- src/slicegpt/rotate.py | 10 ++- src/slicegpt/slicing_scheduler.py | 128 +++++++++++++++++++++++------- 3 files changed, 117 insertions(+), 38 deletions(-) diff --git a/src/slicegpt/model_adapter.py b/src/slicegpt/model_adapter.py index f9b0f8bc..28315d67 100644 --- a/src/slicegpt/model_adapter.py +++ b/src/slicegpt/model_adapter.py @@ -438,21 +438,22 @@ class SlicingConfig: do_slice_head: bool = False parallel_blocks: bool = False - # use dict[int, int] instead of list[int] to allow for arbitrary order updates and default dicts + # both sequential and parallel blocks case embedding_dimensions: dict[int, int] = field(default_factory=dict) - attention_input_dimensions: dict[int, int] = field(default_factory=dict) - attention_output_dimensions: dict[int, int] = field(default_factory=dict) + mlp_output_dimensions: dict[int, int] = field(default_factory=dict) + # the 2nd path for the sequential blocks case + attention_output_dimensions: dict[int, int] = field(default_factory=dict) mlp_input_dimensions: dict[int, int] = field(default_factory=dict) - mlp_output_dimensions: dict[int, int] = field(default_factory=dict) head_dimension: int | None = None - const_dimension: int | None = None # to be able to load models without config, sliced with const sparsity + # used when loading models sliced with const sparsity that are missing a json config + const_dimension: int | None = None @staticmethod - def from_dict(d: dict) -> 'SlicingConfig': + def from_dict(d: dict) -> SlicingConfig: """Return a SliceConfig object constructed from the provided dictionary.""" def convert_dict_keys_to_int(d: Any) -> Any: @@ -470,7 +471,7 @@ def convert_dict_keys_to_int(d: Any) -> Any: return SlicingConfig(**convert_dict_keys_to_int(d)) @staticmethod - def from_json_string(json_str: str) -> 'SlicingConfig': + def from_json_string(json_str: str) -> SlicingConfig: """Return a SliceConfig object constructed from the provided JSON string.""" return SlicingConfig.from_dict(json.loads(json_str)) @@ -485,6 +486,6 @@ def to_json_string(self) -> str: """Return a JSON representation of this object.""" return json.dumps(self.to_dict()) - def clone(self) -> 'SlicingConfig': + def clone(self) -> SlicingConfig: """Return a clone of this object.""" return copy.deepcopy(self) diff --git a/src/slicegpt/rotate.py b/src/slicegpt/rotate.py index d201e323..b93a4e47 100644 --- a/src/slicegpt/rotate.py +++ b/src/slicegpt/rotate.py @@ -166,6 +166,7 @@ def rotate_and_slice_sequential( # rotate and slice embeddings eig_val, Q = pca_calc(inps, ignore_masks) + slicing_scheduler.set_embedding_eigenvalues(eig_val.detach().cpu().tolist()) Q = Q.to(device=config.device) if final_orientation == 'random': R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_embedding_dimensions()[0]) @@ -193,6 +194,7 @@ def rotate_and_slice_sequential( mlp_ln_inputs, _ = get_signals(layer_adapter, args, kwargs) eig_val, Q = pca_calc(mlp_ln_inputs, ignore_masks) + slicing_scheduler.set_mlp_eigenvalues(idx, eig_val.detach().cpu().tolist()) Q = Q.to(device=config.device, dtype=torch.float64) if final_orientation == 'random': R = random_orthogonal_upper_left( @@ -224,6 +226,8 @@ def rotate_and_slice_sequential( # with slicing between Attention and mlp. 
_, inps = get_signals(layer_adapter, args, kwargs) eig_val, Q = pca_calc(inps, ignore_masks) + slicing_scheduler.set_attention_eigenvalues(idx, eig_val.detach().cpu().tolist()) + if final_orientation == 'random': R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_mlp_output_dimension(idx)) Q = Q @ R.to(Q.device) @@ -279,7 +283,8 @@ def rotate_and_slice_parallel( slicing_scheduler.setup(hidden_size=model_adapter.hidden_size, layers_num=len(layers), parallel_blocks=True) # rotate and slice embeddings - _, Q = pca_calc(inps, ignore_masks) + eig_val, Q = pca_calc(inps, ignore_masks) + slicing_scheduler.set_embedding_eigenvalues(eig_val.detach().cpu().tolist()) Q = Q.to(device=config.device) if final_orientation == 'random': R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_embedding_dimensions()[0]) @@ -322,7 +327,8 @@ def rotate_and_slice_parallel( outputs.append(out) inps = outputs - _, Q = pca_calc(inps, ignore_masks) + eig_val, Q = pca_calc(inps, ignore_masks) + slicing_scheduler.set_attention_eigenvalues(idx, eig_val.detach().cpu().tolist()) if final_orientation == 'random': R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_mlp_output_dimension(idx)) diff --git a/src/slicegpt/slicing_scheduler.py b/src/slicegpt/slicing_scheduler.py index e06b78f0..2c899e75 100644 --- a/src/slicegpt/slicing_scheduler.py +++ b/src/slicegpt/slicing_scheduler.py @@ -2,6 +2,8 @@ from collections import defaultdict from typing import Callable, final +import numpy as np + from slicegpt.model_adapter import SlicingConfig @@ -20,6 +22,11 @@ def __init__(self, *, do_slice_head: bool = False): self.slicing_conf: SlicingConfig = SlicingConfig() self.slicing_conf.do_slice_head = do_slice_head + # eigenvalues obtained from PCA + self.embedding_eigenvalues: list[float] = [] + self.attention_eigenvalues: dict[int, list[float]] = {} + self.mlp_eigenvalues: dict[int, list[float]] = {} + @property def do_slice_head(self) -> bool: """Return whether to slice the head.""" @@ -49,17 +56,23 @@ def setup(self, *, hidden_size: int, layers_num: int, parallel_blocks: bool) -> @final def get_embedding_dimensions(self) -> dict[int, int]: """Return the input embedding dimensions.""" - val = self._get_input_embedding_dimensions() + if self.slicing_conf.embedding_dimensions: + return self.slicing_conf.embedding_dimensions + + val = self._get_embedding_dimensions() self.slicing_conf.embedding_dimensions = val return val @abstractmethod - def _get_input_embedding_dimensions(self) -> dict[int, int]: + def _get_embedding_dimensions(self) -> dict[int, int]: raise NotImplementedError @final def get_attention_input_dimension(self, idx: int) -> int: """Return the attention input dimension for the specified layer index.""" + if idx in self.slicing_conf.attention_input_dimensions: + return self.slicing_conf.attention_input_dimensions[idx] + val = self._get_attention_input_dimension(idx) self.slicing_conf.attention_input_dimensions[idx] = val return val @@ -69,12 +82,30 @@ def _get_attention_input_dimension(self, idx: int) -> int: raise NotImplementedError @final - def get_attention_output_dimension(self, idx, match_head_dim: bool) -> int: + def get_mlp_output_dimension(self, idx: int) -> int: + """Return the mlp output dimension for the specified layer index.""" + if idx in self.slicing_conf.mlp_output_dimensions: + return self.slicing_conf.mlp_output_dimensions[idx] + + use_head_dim = idx == self.layers_num - 1 + val = self._get_mlp_output_dimension(idx) if not use_head_dim else 
self.get_head_dimension() + self.slicing_conf.mlp_output_dimensions[idx] = val + return val + + @abstractmethod + def _get_mlp_output_dimension(self, idx: int) -> int: + raise NotImplementedError + + @final + def get_attention_output_dimension(self, idx, match_head_dim: bool | None = None) -> int: """Return the attention output dimension for the specified layer index.""" if self.parallel_blocks: return self.get_mlp_output_dimension(idx) - use_head_dim = match_head_dim and idx == self.layers_num - 1 + if idx in self.slicing_conf.attention_output_dimensions: + return self.slicing_conf.attention_output_dimensions[idx] + + use_head_dim = idx == self.layers_num - 1 and match_head_dim val = self._get_attention_output_dimension(idx) if not use_head_dim else self.get_head_dimension() self.slicing_conf.attention_output_dimensions[idx] = val return val @@ -89,6 +120,9 @@ def get_mlp_input_dimension(self, idx: int) -> int: if self.parallel_blocks: return self.get_attention_input_dimension(idx) + if idx in self.slicing_conf.mlp_input_dimensions: + return self.slicing_conf.mlp_input_dimensions[idx] + val = self._get_mlp_input_dimension(idx) self.slicing_conf.mlp_input_dimensions[idx] = val return val @@ -97,21 +131,12 @@ def get_mlp_input_dimension(self, idx: int) -> int: def _get_mlp_input_dimension(self, idx: int) -> int: raise NotImplementedError - @final - def get_mlp_output_dimension(self, idx: int) -> int: - """Return the mlp output dimension for the specified layer index.""" - use_head_dim = idx == self.layers_num - 1 - val = self._get_mlp_output_dimension(idx) if not use_head_dim else self.get_head_dimension() - self.slicing_conf.mlp_output_dimensions[idx] = val - return val - - @abstractmethod - def _get_mlp_output_dimension(self, idx: int) -> int: - raise NotImplementedError - @final def get_head_dimension(self) -> int: """Return the LM head dimension.""" + if self.slicing_conf.head_dimension is not None: + return self.slicing_conf.head_dimension + val = self._get_head_dimension() if self.slicing_conf.do_slice_head else self.hidden_size self.slicing_conf.head_dimension = val return val @@ -120,6 +145,18 @@ def get_head_dimension(self) -> int: def _get_head_dimension(self) -> int: raise NotImplementedError + def set_embedding_eigenvalues(self, eigenvalues: list[float]) -> None: + """Set the eigenvalues of the embeddings PCA.""" + self.embedding_eigenvalues = eigenvalues + + def set_attention_eigenvalues(self, idx: int, eigenvalues: list[float]) -> None: + """Set the eigenvalues of the attention layer PCA.""" + self.attention_eigenvalues[idx] = eigenvalues + + def set_mlp_eigenvalues(self, idx: int, eigenvalues: list[float]) -> None: + """Set the eigenvalues of the MLP layer PCA.""" + self.mlp_eigenvalues[idx] = eigenvalues + class ConfigSlicingScheduler(SlicingScheduler): """Slicing scheduler that returns the dimensions specified in the config.""" @@ -128,21 +165,21 @@ def __init__(self, config: SlicingConfig): super().__init__() self.slicing_conf = config - def _get_input_embedding_dimensions(self) -> dict[int, int]: + def _get_embedding_dimensions(self) -> dict[int, int]: return self.slicing_conf.embedding_dimensions def _get_attention_input_dimension(self, idx: int) -> int: return self.slicing_conf.attention_input_dimensions[idx] + def _get_mlp_output_dimension(self, idx: int) -> int: + return self.slicing_conf.mlp_output_dimensions[idx] + def _get_attention_output_dimension(self, idx: int) -> int: return self.slicing_conf.attention_output_dimensions[idx] def _get_mlp_input_dimension(self, 
idx: int) -> int: return self.slicing_conf.mlp_input_dimensions[idx] - def _get_mlp_output_dimension(self, idx: int) -> int: - return self.slicing_conf.mlp_output_dimensions[idx] - def _get_head_dimension(self) -> int: return self.slicing_conf.head_dimension @@ -154,19 +191,19 @@ def __init__(self, dimension: int, *, do_slice_head: bool = False): super().__init__(do_slice_head=do_slice_head) self.dimension: int = dimension - def _get_input_embedding_dimensions(self) -> dict[int, int]: + def _get_embedding_dimensions(self) -> dict[int, int]: return defaultdict(lambda: self.dimension) def _get_attention_input_dimension(self, idx: int) -> int: return self.dimension - def _get_attention_output_dimension(self, idx: int) -> int: + def _get_mlp_output_dimension(self, idx: int) -> int: return self.dimension - def _get_mlp_input_dimension(self, idx: int) -> int: + def _get_attention_output_dimension(self, idx: int) -> int: return self.dimension - def _get_mlp_output_dimension(self, idx: int) -> int: + def _get_mlp_input_dimension(self, idx: int) -> int: return self.dimension def _get_head_dimension(self) -> int: @@ -186,13 +223,13 @@ def __init__(self, *, do_slice_head: bool = False): def _get_attention_input_dimension(self, idx: int) -> int: # return the input embedding dimension when at the first attn layer inputs if idx == 0: - return self._get_input_embedding_dimensions()[0] # all dimensions are the same there + return self.get_embedding_dimensions()[0] # all dimensions are the same there - return self._get_mlp_output_dimension(idx - 1) + return self.get_mlp_output_dimension(idx - 1) @final def _get_mlp_input_dimension(self, idx: int) -> int: - return self._get_attention_output_dimension(idx) + return self.get_attention_output_dimension(idx) class FunctionSlicingScheduler(ForwardSlicingScheduler): @@ -222,7 +259,7 @@ def _get_layer_dimension(self, idx: int, is_attn_layer: bool = False) -> int: val -= val % self.round_interval return val - def _get_input_embedding_dimensions(self) -> dict[int, int]: + def _get_embedding_dimensions(self) -> dict[int, int]: return defaultdict(lambda: self._get_layer_dimension(0)) def _get_attention_output_dimension(self, idx: int) -> int: @@ -259,3 +296,38 @@ def linear_sparsity_func(location: float) -> float: round_interval=round_interval, do_slice_head=do_slice_head, ) + + +class ExplainedVarianceSlicingScheduler(ForwardSlicingScheduler): + """A slicing scheduler that applies sparsity based on the explained variance from the PCA.""" + + def __init__( + self, + *, + uev_threshold: float, + round_interval: int = 1, + do_slice_head: bool = False, + ): + super().__init__(do_slice_head=do_slice_head) + self.uev_threshold: float = uev_threshold + self.round_interval: int = round_interval + + def _get_layer_dimension(self, eigen_vals: list[float], plot: bool = False) -> int: + eigen_vals = np.array(eigen_vals) + cum_var = np.cumsum(np.array(eigen_vals)) / np.sum(eigen_vals) + dim = np.argmax(cum_var > 1 - self.uev_threshold) + dim -= dim % self.round_interval + dim = int(dim) + return dim + + def _get_embedding_dimensions(self) -> dict[int, int]: + return defaultdict(lambda: self._get_layer_dimension(self.embedding_eigenvalues)) + + def _get_attention_output_dimension(self, idx: int) -> int: + return self._get_layer_dimension(self.mlp_eigenvalues[idx]) + + def _get_mlp_output_dimension(self, idx: int) -> int: + return self._get_layer_dimension(self.attention_eigenvalues[idx]) + + def _get_head_dimension(self) -> int: + return 
self.get_attention_output_dimension(self.layers_num - 1) From c933c9b2501c2ade98aedf811b115726eeb611a8 Mon Sep 17 00:00:00 2001 From: Pavel Myshkov Date: Mon, 19 Feb 2024 15:01:25 +0000 Subject: [PATCH 2/4] Fix formatting --- src/slicegpt/rotate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/slicegpt/rotate.py b/src/slicegpt/rotate.py index b93a4e47..f085e175 100644 --- a/src/slicegpt/rotate.py +++ b/src/slicegpt/rotate.py @@ -227,7 +227,7 @@ def rotate_and_slice_sequential( _, inps = get_signals(layer_adapter, args, kwargs) eig_val, Q = pca_calc(inps, ignore_masks) slicing_scheduler.set_attention_eigenvalues(idx, eig_val.detach().cpu().tolist()) - + if final_orientation == 'random': R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_mlp_output_dimension(idx)) Q = Q @ R.to(Q.device) From 1921764129ec948cce9a6daad3414dc973221c73 Mon Sep 17 00:00:00 2001 From: Pavel Myshkov Date: Mon, 19 Feb 2024 15:37:44 +0000 Subject: [PATCH 3/4] Remove empty line --- src/slicegpt/rotate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/slicegpt/rotate.py b/src/slicegpt/rotate.py index f085e175..ee68f92c 100644 --- a/src/slicegpt/rotate.py +++ b/src/slicegpt/rotate.py @@ -227,7 +227,6 @@ def rotate_and_slice_sequential( _, inps = get_signals(layer_adapter, args, kwargs) eig_val, Q = pca_calc(inps, ignore_masks) slicing_scheduler.set_attention_eigenvalues(idx, eig_val.detach().cpu().tolist()) - if final_orientation == 'random': R = random_orthogonal_upper_left(Q.shape[0], slicing_scheduler.get_mlp_output_dimension(idx)) Q = Q @ R.to(Q.device) From 4b056e29119fc0388f88b4140740180af54cfdac Mon Sep 17 00:00:00 2001 From: Pavel Myshkov Date: Mon, 19 Feb 2024 16:56:28 +0000 Subject: [PATCH 4/4] Add copyright, use string annotations --- src/slicegpt/slicing_scheduler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/slicegpt/slicing_scheduler.py b/src/slicegpt/slicing_scheduler.py index 2c899e75..fe551673 100644 --- a/src/slicegpt/slicing_scheduler.py +++ b/src/slicegpt/slicing_scheduler.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + from abc import ABC, abstractmethod from collections import defaultdict from typing import Callable, final @@ -279,7 +284,7 @@ def create_linear( attn_end: float | None = None, round_interval: int = 1, do_slice_head: bool = False, - ) -> 'FunctionSlicingScheduler': + ) -> FunctionSlicingScheduler: """Create a linear slicing scheduler, mainly as an example for testing.""" def linear(start: float, end: float) -> Callable[[float], float]:
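
Note (not part of the patch series above): a standalone sketch of the explained-variance cut that the new ExplainedVarianceSlicingScheduler applies per layer — keep the smallest leading dimension whose cumulative explained variance reaches 1 - uev_threshold, rounded down to the round interval. The helper name explained_variance_dimension and the eigenvalue spectrum below are made up for illustration; in the patch the eigenvalues come from pca_calc via the new set_embedding_eigenvalues / set_attention_eigenvalues / set_mlp_eigenvalues hooks, and the cut is implemented in _get_layer_dimension.

    import numpy as np

    def explained_variance_dimension(eigen_vals, uev_threshold=0.05, round_interval=8):
        """Mirror of the scheduler's _get_layer_dimension: return the first index
        at which the cumulative explained variance exceeds 1 - uev_threshold,
        rounded down to a multiple of round_interval."""
        eigen_vals = np.asarray(eigen_vals, dtype=float)
        cum_var = np.cumsum(eigen_vals) / np.sum(eigen_vals)
        dim = int(np.argmax(cum_var > 1 - uev_threshold))
        dim -= dim % round_interval
        return dim

    # Hypothetical, fast-decaying spectrum standing in for the PCA eigenvalues
    # of a 4096-wide hidden state.
    eigen_vals = np.exp(-np.arange(4096) / 512.0)
    print(explained_variance_dimension(eigen_vals, uev_threshold=0.05, round_interval=8))

Unlike ConstSlicingScheduler, which fixes one dimension for every layer, this rule lets each layer keep as many directions as its own spectrum requires for the chosen unexplained-variance budget.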