-
Notifications
You must be signed in to change notification settings - Fork 80
build: update ty to v0.0.16 and align type-checking config #535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
642a6a4
e05bae8
f521b21
88ae72a
22731e8
932cb22
df74955
b43041c
c00e487
d3a842f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| name: PR Title Convention | ||
|
|
||
| on: | ||
| pull_request: | ||
| types: [opened, edited, synchronize, reopened] | ||
| branches: [main] | ||
|
|
||
| permissions: | ||
| pull-requests: read | ||
|
|
||
| jobs: | ||
| check-title: | ||
| name: Validate PR title | ||
| runs-on: ubuntu-22.04 | ||
| timeout-minutes: 1 | ||
| steps: | ||
| - name: Check conventional commit format | ||
| env: | ||
| PR_TITLE: ${{ github.event.pull_request.title }} | ||
| run: | | ||
| # Allowed conventional commit types | ||
| TYPES="feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert" | ||
|
|
||
| # Pattern: type(optional-scope): description | ||
| # OR: type!: description (breaking change) | ||
| PATTERN="^($TYPES)(\(.+\))?\!?: .+" | ||
|
|
||
| if echo "$PR_TITLE" | grep -qP "$PATTERN"; then | ||
| echo "PR title is valid: $PR_TITLE" | ||
| else | ||
| echo "::error::PR title does not follow Conventional Commits." | ||
| echo "" | ||
| echo "Got: $PR_TITLE" | ||
| echo "" | ||
| echo "Expected: <type>[optional scope]: <description>" | ||
| echo "" | ||
| echo "Allowed types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert" | ||
| echo "Read more: https://www.conventionalcommits.org/en/v1.0.0/" | ||
| echo "" | ||
| echo "Examples:" | ||
| echo " feat: add new optimization algorithm" | ||
| echo " fix: resolve memory leak in model loading" | ||
| echo " ci(pruna): pin transformers version" | ||
| echo "" | ||
| exit 1 | ||
| fi |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,7 @@ | |
|
|
||
| from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase | ||
| from pruna.algorithms.base.tags import AlgorithmTag as tags | ||
| from pruna.config.smash_config import SmashConfigPrefixWrapper | ||
| from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper | ||
| from pruna.config.target_modules import ( | ||
| TARGET_MODULES_TYPE, | ||
| TargetModules, | ||
|
|
@@ -130,8 +130,8 @@ def model_check_fn(self, model: Any) -> bool: | |
| return any(isinstance(attr_value, tuple(transformer_and_unet_models)) for attr_value in model.__dict__.values()) | ||
|
|
||
| def get_model_dependent_hyperparameter_defaults( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The signature for
In all cases, these signature changes shouldn't be made for algorithms that return e.g. |
||
| self, model: Any, smash_config: SmashConfigPrefixWrapper | ||
| ) -> dict[str, Any]: | ||
| self, model: Any, smash_config: SmashConfig | SmashConfigPrefixWrapper | ||
| ) -> TARGET_MODULES_TYPE: # ty: ignore[invalid-method-override] | ||
| """ | ||
| Provide default `target_modules` by detecting transformer and unet components in the pipeline. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ | |
|
|
||
| from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase | ||
| from pruna.algorithms.base.tags import AlgorithmTag as tags | ||
| from pruna.config.smash_config import SmashConfigPrefixWrapper | ||
| from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper | ||
| from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules, map_targeted_nn_roots | ||
| from pruna.engine.save import SAVE_FUNCTIONS | ||
|
|
||
|
|
@@ -91,10 +91,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: | |
| target_modules = smash_config["target_modules"] | ||
|
|
||
| if target_modules is None: | ||
| target_modules = self.get_model_dependent_hyperparameter_defaults( | ||
| model, | ||
| smash_config | ||
| ) | ||
| target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config) | ||
|
|
||
| def apply_sage_attn( | ||
| root_name: str | None, | ||
|
|
@@ -153,8 +150,8 @@ def get_hyperparameters(self) -> list: | |
| def get_model_dependent_hyperparameter_defaults( | ||
| self, | ||
| model: Any, | ||
| smash_config: SmashConfigPrefixWrapper, | ||
| ) -> TARGET_MODULES_TYPE: | ||
| smash_config: SmashConfig | SmashConfigPrefixWrapper, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There shouldn't be any change in this file, but there is indeed a type problem here since the file hasn't been updated when changing the return type for the base class' method. I opened a quick PR #540 fixing exactly this |
||
| ) -> TARGET_MODULES_TYPE: # ty: ignore[invalid-method-override] | ||
| """ | ||
| Provide default `target_modules` targeting all transformer modules. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I fully agree that we should ignore the comments for now, and I'll go through the code in a future PR to remove those ignore statements one by one because this frankly isn't checking anything...