diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py
index a20c2537..aa921e6c 100644
--- a/src/pruna/algorithms/sage_attn.py
+++ b/src/pruna/algorithms/sage_attn.py
@@ -19,6 +19,7 @@
 import torch
 from diffusers import DiffusionPipeline
+from typing import cast
 
 from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
 from pruna.algorithms.base.tags import AlgorithmTag as tags
 
@@ -91,10 +92,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
 
         target_modules = smash_config["target_modules"]
         if target_modules is None:
-            target_modules = self.get_model_dependent_hyperparameter_defaults(
-                model,
-                smash_config
-            )
+            target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)["target_modules"]
+            target_modules = cast(TARGET_MODULES_TYPE, target_modules)
 
         def apply_sage_attn(
             root_name: str | None,
@@ -154,7 +153,7 @@ def get_model_dependent_hyperparameter_defaults(
         self,
         model: Any,
         smash_config: SmashConfigPrefixWrapper,
-    ) -> TARGET_MODULES_TYPE:
+    ) -> dict[str, Any]:
         """
         Provide default `target_modules` targeting all transformer modules.
 
@@ -178,5 +177,5 @@ def get_model_dependent_hyperparameter_defaults(
         # SageAttn might also be applicable to other modules but could significantly decrease model quality.
         include = ["transformer*"]
         exclude = []
-
-        return {"include": include, "exclude": exclude}
+        target_modules = {"include": include, "exclude": exclude}
+        return {"target_modules": target_modules}
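
Reviewer note: a minimal, self-contained sketch of the new return contract, for anyone skimming the diff. The `model`/`smash_config` stubs and the `TARGET_MODULES_TYPE` alias below are placeholders for illustration (the real alias is defined elsewhere in pruna and is not part of this diff); only the dict shape mirrors the change.

from typing import Any, cast

# Stand-in for pruna's TARGET_MODULES_TYPE alias -- hypothetical here;
# the real definition lives elsewhere in the repository.
TARGET_MODULES_TYPE = dict[str, list[str]]

def get_model_dependent_hyperparameter_defaults(model: Any, smash_config: Any) -> dict[str, Any]:
    # After this change, defaults are keyed by hyperparameter name rather
    # than returned as a bare target_modules dict.
    target_modules = {"include": ["transformer*"], "exclude": []}
    return {"target_modules": target_modules}

# Caller side, mirroring _apply: select the one default and narrow its
# type for the type checker.
defaults = get_model_dependent_hyperparameter_defaults(model=None, smash_config=None)
target_modules = cast(TARGET_MODULES_TYPE, defaults["target_modules"])
print(target_modules)  # {'include': ['transformer*'], 'exclude': []}

Keying the result by hyperparameter name lets other algorithms return several model-dependent defaults from one hook, at the cost of the `cast` at the call site since the dict value type is erased to `Any`.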