Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions .github/workflows/pr-title.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Validate that every PR title follows the Conventional Commits format,
# so squash-merge commit messages stay machine-parseable for changelogs.
name: PR Title Convention

on:
  pull_request:
    # Re-check on edit/synchronize/reopen so a fixed title clears the failure
    types: [opened, edited, synchronize, reopened]
    branches: [main]

permissions:
  pull-requests: read

jobs:
  check-title:
    name: Validate PR title
    runs-on: ubuntu-22.04
    timeout-minutes: 1
    steps:
      - name: Check conventional commit format
        env:
          # Passed via env (not inline interpolation) to avoid shell injection
          # through a crafted PR title.
          PR_TITLE: ${{ github.event.pull_request.title }}
        run: |
          # Allowed conventional commit types
          TYPES="feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert"

          # Pattern: type(optional-scope): description
          # OR: type!: description (breaking change)
          PATTERN="^($TYPES)(\(.+\))?\!?: .+"

          if echo "$PR_TITLE" | grep -qP "$PATTERN"; then
            echo "PR title is valid: $PR_TITLE"
          else
            echo "::error::PR title does not follow Conventional Commits."
            echo ""
            echo "Got: $PR_TITLE"
            echo ""
            echo "Expected: <type>[optional scope]: <description>"
            echo ""
            echo "Allowed types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert"
            echo "Read more: https://www.conventionalcommits.org/en/v1.0.0/"
            echo ""
            echo "Examples:"
            echo "  feat: add new optimization algorithm"
            echo "  fix: resolve memory leak in model loading"
            echo "  ci(pruna): pin transformers version"
            echo ""
            exit 1
          fi
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ repos:
hooks:
- id: ty
name: type checking using ty
entry: uvx ty check .
entry: uvx ty check src/pruna
language: system
types: [python]
pass_filenames: false
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ invalid-return-type = "ignore" # mypy is more permissive with return types
invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults
no-matching-overload = "ignore" # mypy is more permissive with overloads
unresolved-reference = "ignore" # mypy is more permissive with references
possibly-unbound-import = "ignore"
possibly-missing-import = "ignore"
possibly-missing-attribute = "ignore"
missing-argument = "ignore"
unused-type-ignore-comment = "ignore"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fully agree that we should ignore the comments for now, and I'll go through the code in a future PR to remove those ignore statements one by one because this frankly isn't checking anything...


[tool.coverage.run]
source = ["src/pruna"]
Expand Down Expand Up @@ -181,7 +183,7 @@ dev = [
"pytest-rerunfailures",
"coverage",
"docutils",
"ty==0.0.1a21",
"ty==0.0.17",
"types-PyYAML",
"logbar",
"pytest-xdist>=3.8.0",
Expand Down
4 changes: 2 additions & 2 deletions src/pruna/algorithms/c_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __call__(
x_tensor = x["input_ids"]
else:
x_tensor = x
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))]
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable]
return self.generator.generate_batch(token_list, min_length=min_length, max_length=max_length, *args, **kwargs) # type: ignore[operator]


Expand Down Expand Up @@ -468,7 +468,7 @@ def __call__(
x_tensor = x["input_ids"]
else:
x_tensor = x
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))]
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable]
return self.translator.translate_batch( # type: ignore[operator]
token_list,
min_decoding_length=min_decoding_length,
Expand Down
6 changes: 3 additions & 3 deletions src/pruna/algorithms/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.config.target_modules import (
TARGET_MODULES_TYPE,
TargetModules,
Expand Down Expand Up @@ -134,8 +134,8 @@ def model_check_fn(self, model: Any) -> bool:
return is_causal_lm(model) or is_janus_llamagen_ar(model) or is_transformers_pipeline_with_causal_lm(model)

def get_model_dependent_hyperparameter_defaults(
self, model: Any, smash_config: SmashConfigPrefixWrapper
) -> dict[str, Any]:
self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper
) -> TARGET_MODULES_TYPE: # ty: ignore[invalid-method-override]
"""
Provide default `target_modules` using `target_backbone` to target the model backbone.

Expand Down
6 changes: 3 additions & 3 deletions src/pruna/algorithms/hqq_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.config.target_modules import (
TARGET_MODULES_TYPE,
TargetModules,
Expand Down Expand Up @@ -130,8 +130,8 @@ def model_check_fn(self, model: Any) -> bool:
return any(isinstance(attr_value, tuple(transformer_and_unet_models)) for attr_value in model.__dict__.values())

def get_model_dependent_hyperparameter_defaults(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signature for get_model_dependent_hyperparameter_defaults changed in #520 :

  • the smash_config argument is always a SmashConfigPrefixWrapper
  • the output type is a dict[str, Any] where keys are hyperparameter names and values are their associated default values. Unfortunately, a TARGET_MODULES_TYPE value also fits that type, so some algorithms haven't been properly updated.

In any case, these signature changes shouldn't be made for algorithms that return e.g. {"target_modules": some_TARGET_MODULES_TYPE_value}.

self, model: Any, smash_config: SmashConfigPrefixWrapper
) -> dict[str, Any]:
self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper
) -> TARGET_MODULES_TYPE: # ty: ignore[invalid-method-override]
"""
Provide default `target_modules` by detecting transformer and unet components in the pipeline.

Expand Down
6 changes: 3 additions & 3 deletions src/pruna/algorithms/llm_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.config.target_modules import (
TARGET_MODULES_TYPE,
TargetModules,
Expand Down Expand Up @@ -98,8 +98,8 @@ def model_check_fn(self, model: Any) -> bool:
return is_causal_lm(model) or is_transformers_pipeline_with_causal_lm(model)

def get_model_dependent_hyperparameter_defaults(
self, model: Any, smash_config: SmashConfigPrefixWrapper
) -> dict[str, Any]:
self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper
) -> TARGET_MODULES_TYPE: # ty: ignore[invalid-method-override]
"""
Provide default `target_modules` using `target_backbone` to target the model backbone.

Expand Down
11 changes: 4 additions & 7 deletions src/pruna/algorithms/sage_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules, map_targeted_nn_roots
from pruna.engine.save import SAVE_FUNCTIONS

Expand Down Expand Up @@ -91,10 +91,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
target_modules = smash_config["target_modules"]

if target_modules is None:
target_modules = self.get_model_dependent_hyperparameter_defaults(
model,
smash_config
)
target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)

def apply_sage_attn(
root_name: str | None,
Expand Down Expand Up @@ -153,8 +150,8 @@ def get_hyperparameters(self) -> list:
def get_model_dependent_hyperparameter_defaults(
self,
model: Any,
smash_config: SmashConfigPrefixWrapper,
) -> TARGET_MODULES_TYPE:
smash_config: SmashConfig | SmashConfigPrefixWrapper,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There shouldn't be any change in this file, but there is indeed a type problem here since the file hasn't been updated when changing the return type for the base class' method. I opened a quick PR #540 fixing exactly this

) -> TARGET_MODULES_TYPE: # ty: ignore[invalid-method-override]
"""
Provide default `target_modules` targeting all transformer modules.

Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/torch_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
else:
modules_to_quantize = {torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Linear}

quantized_model = torch.quantization.quantize_dynamic(
quantized_model = torch.quantization.quantize_dynamic( # type: ignore[deprecated]
model,
modules_to_quantize,
dtype=getattr(torch, smash_config["weight_bits"]),
Expand Down
6 changes: 3 additions & 3 deletions src/pruna/algorithms/torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules, map_targeted_nn_roots, target_backbone
from pruna.engine.model_checks import (
get_diffusers_transformer_models,
Expand Down Expand Up @@ -174,8 +174,8 @@ def model_check_fn(self, model: Any) -> bool:
return isinstance(model, torch.nn.Module)

def get_model_dependent_hyperparameter_defaults(
self, model: Any, smash_config: SmashConfigPrefixWrapper
) -> dict[str, Any]:
self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper
) -> TARGET_MODULES_TYPE: # ty: ignore[invalid-method-override]
"""
Provide default `target_modules` using `target_backbone`, with additional exclusions.

Expand Down
2 changes: 1 addition & 1 deletion src/pruna/config/smash_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def load_from_json(self, path: str | Path) -> None:
setattr(self, name, config_dict.pop(name))

# Keep only values that still exist in the space, drop stale keys
supported_hparam_names = {hp.name for hp in SMASH_SPACE.get_hyperparameters()}
supported_hparam_names = {hp.name for hp in list(SMASH_SPACE.values())}
saved_values = {k: v for k, v in config_dict.items() if k in supported_hparam_names}

# Seed with the defaults, then overlay the saved values
Expand Down