From 01d797638e2c2204f00c38d21091f02c843eb349 Mon Sep 17 00:00:00 2001 From: Gaspar Rochette Date: Wed, 18 Feb 2026 15:37:51 +0000 Subject: [PATCH] fix: allign sage_attn default hyperparameter return type with refactor --- src/pruna/algorithms/sage_attn.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py index a20c2537..aa921e6c 100644 --- a/src/pruna/algorithms/sage_attn.py +++ b/src/pruna/algorithms/sage_attn.py @@ -19,6 +19,7 @@ import torch from diffusers import DiffusionPipeline +from typing_extensions import cast from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags @@ -91,10 +92,8 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: target_modules = smash_config["target_modules"] if target_modules is None: - target_modules = self.get_model_dependent_hyperparameter_defaults( - model, - smash_config - ) + target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)["target_modules"] + target_modules = cast(TARGET_MODULES_TYPE, target_modules) def apply_sage_attn( root_name: str | None, @@ -154,7 +153,7 @@ def get_model_dependent_hyperparameter_defaults( self, model: Any, smash_config: SmashConfigPrefixWrapper, - ) -> TARGET_MODULES_TYPE: + ) -> dict[str, Any]: """ Provide default `target_modules` targeting all transformer modules. @@ -178,5 +177,5 @@ def get_model_dependent_hyperparameter_defaults( # SageAttn might also be applicable to other modules but could significantly decrease model quality. include = ["transformer*"] exclude = [] - - return {"include": include, "exclude": exclude} + target_modules = {"include": include, "exclude": exclude} + return {"target_modules": target_modules}