4 changes: 3 additions & 1 deletion pyproject.toml
@@ -137,7 +137,9 @@ dependencies = [
"vbench-pruna; sys_platform != 'darwin'",
"imageio-ffmpeg",
"jaxtyping",
"peft>=0.17.1",
"peft>=0.18.0",
"trl<=0.21.0",
"termcolor==2.3.0",
]

[project.optional-dependencies]
4 changes: 4 additions & 0 deletions src/pruna/algorithms/base/tags.py
@@ -79,6 +79,10 @@ class AlgorithmTag(Enum):
"resampler",
"Resamplers change the shape of image or video latents during generation to speed up inference.",
)
RECOVERER = (
"recoverer",
"Recovery restores the performance of a model after compression.",
)

def __init__(self, name: str, description: str):
"""
73 changes: 73 additions & 0 deletions src/pruna/algorithms/distillation_perp.py
@@ -0,0 +1,73 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Iterable

from pruna.algorithms.base.tags import AlgorithmTag
from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer


class TextToImagePERPDistillation(PERPRecoverer):
"""
PERP distillation recoverer for text-to-image models.

This is a general-purpose PERP recoverer for text-to-image models that combines norm and bias finetuning
with LoRA layers.

Parameters
----------
use_lora : bool
Whether to use LoRA adapters.
use_in_place : bool
Whether to use norm and bias finetuning, which modifies the model in place.
"""

group_tags: list[AlgorithmTag] = [AlgorithmTag.DISTILLER, AlgorithmTag.RECOVERER] # type: ignore[attr-defined]
algorithm_name = "text_to_image_distillation_perp"
tokenizer_required = False
compatible_before: Iterable[str | AlgorithmTag] = ["quanto", "torch_dynamic", "deepcache"]
compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"]
runs_on: list[str] = ["cuda"]

def __init__(self, use_lora: bool = True, use_in_place: bool = True) -> None:
super().__init__(task_name="text_to_image", use_lora=use_lora, use_in_place=use_in_place, is_distillation=True)


class TextToImageInPlacePERPDistillation(TextToImagePERPDistillation):
"""
PERP distillation recoverer for text-to-image models without LoRA adapters.

This is the same as ``text_to_image_distillation_perp`` but without LoRA layers, which add extra computation
and thus slow down inference of the final model.
"""

algorithm_name = "text_to_image_distillation_inplace_perp"

def __init__(self) -> None:
super().__init__(use_lora=False, use_in_place=True)


class TextToImageLoraDistillation(TextToImagePERPDistillation):
"""
LoRA distillation recoverer for text-to-image models.

This recoverer attaches LoRA adapters to the model and uses them for distillation.
"""

algorithm_name = "text_to_image_distillation_lora"

def __init__(self) -> None:
super().__init__(use_lora=True, use_in_place=False)
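
For context, here is a minimal usage sketch of how one of these recoverers might be selected through pruna's `SmashConfig`. It is not taken from this PR: the `"distiller"` group key, the dataset name, and the model id are assumptions made for illustration.

```python
from diffusers import StableDiffusionPipeline

from pruna import SmashConfig, smash

# base text-to-image pipeline to compress and then recover (model id chosen for illustration)
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

smash_config = SmashConfig()
smash_config["quantizer"] = "quanto"  # a compression step listed in compatible_before
smash_config["distiller"] = "text_to_image_distillation_perp"  # assumed group key, derived from the DISTILLER tag
smash_config.add_data("LAION256")  # assumption: distillation-based recovery needs finetuning data

smashed_pipe = smash(model=pipe, smash_config=smash_config)
```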
13 changes: 13 additions & 0 deletions src/pruna/algorithms/global_utils/recovery/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
101 changes: 101 additions & 0 deletions src/pruna/algorithms/global_utils/recovery/adapters/__init__.py
@@ -0,0 +1,101 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

import torch

from pruna.config.smash_config import SmashConfigPrefixWrapper


class PrunaAdapter(ABC):
"""Base class for adapters, defining which parameters to finetune for recovery."""

@property
@abstractmethod
def adapter_prefix(self) -> str:
"""The prefix of the adapter to use in the config."""
pass

@classmethod
@abstractmethod
def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list:
"""
Configure all algorithm-specific hyperparameters with ConfigSpace.

Parameters
----------
task_name : str
The name of the task, e.g. "text_to_image" or "text_to_text".
**override_defaults : Any
Values used to override the default hyperparameters when using multiple finetuners together.

Returns
-------
list
The hyperparameters.
"""
pass

@classmethod
@abstractmethod
def activate(
cls,
model: torch.nn.Module,
smash_config: SmashConfigPrefixWrapper,
seed: int | None = None,
) -> tuple[torch.nn.Module, int, int]:
"""
Activate or create the parameters in the model corresponding to the adapter.

Parameters
----------
model : torch.nn.Module
The model to apply the component to.
smash_config : SmashConfigPrefixWrapper
The configuration for the component.
seed : int | None
The seed to use for the adapter if it requires initialization.

Returns
-------
torch.nn.Module
The model with the adapter activated.
int
The number of trainable parameters.
int
The number of skipped parameters.
"""
pass

@classmethod
def pre_smash_hook(
cls, model: torch.nn.Module, smash_config: SmashConfigPrefixWrapper, seed: int | None = None
) -> None:
"""
Optional hook to prepare the model/config before smashing.

Parameters
----------
model : torch.nn.Module
The model to prepare.
smash_config : SmashConfigPrefixWrapper
Configuration scoped to this adapter.
seed : int | None
Optional seed for deterministic initialization.
"""
pass
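
To show how the three members of `PrunaAdapter` fit together, here is a hypothetical subclass that is not part of this PR: a sketch that unfreezes normalization-layer parameters, with the counting convention mirroring the abstract docstring above.

```python
from __future__ import annotations

import torch

from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter
from pruna.config.smash_config import SmashConfigPrefixWrapper


class NormAdapter(PrunaAdapter):
    """Hypothetical adapter that unfreezes LayerNorm/GroupNorm parameters (illustration only)."""

    adapter_prefix = "norm"

    @classmethod
    def get_hyperparameters(cls, task_name: str, **override_defaults) -> list:
        # no tunable hyperparameters in this sketch
        return []

    @classmethod
    def activate(
        cls, model: torch.nn.Module, smash_config: SmashConfigPrefixWrapper, seed: int | None = None
    ) -> tuple[torch.nn.Module, int, int]:
        num_trainable, num_skipped = 0, 0
        for module in model.modules():
            if isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
                for param in module.parameters(recurse=False):
                    if param.is_floating_point():
                        param.requires_grad = True
                        num_trainable += int(param.numel())
                    else:
                        num_skipped += int(param.numel())
        return model, num_trainable, num_skipped
```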
68 changes: 68 additions & 0 deletions src/pruna/algorithms/global_utils/recovery/adapters/bias.py
@@ -0,0 +1,68 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter, utils


class BiasAdapter(PrunaAdapter):
"""Adapter for bias finetuning."""

adapter_prefix = "bias"

@classmethod
def get_hyperparameters(cls, *args, **kwargs) -> list:
"""
Configure all method-specific hyperparameters with ConfigSpace.

Parameters
----------
*args : Any
Unused arguments.
**kwargs : Any
Unused keyword arguments.

Returns
-------
list
The hyperparameters.
"""
return []

@classmethod
def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]:
"""
Activate all biases for training.

Parameters
----------
model : torch.nn.Module
The model containing the biases.
*args : Any
Unused additional arguments.
**kwargs : Any
Unused additional keyword arguments.

Returns
-------
torch.nn.Module
The model with the biases activated.
int
The number of trainable bias parameters.
int
The number of skipped bias parameters.
"""
num_activ_param, num_skip_param = utils.unfreeze_parameters_by_name(model, target_modules=("bias",))
return model, num_activ_param, num_skip_param
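
The `utils` module is not part of this excerpt, so the exact behaviour of `unfreeze_parameters_by_name` is an assumption; based on how its two return values are used here, it plausibly looks something like the following sketch.

```python
import torch


def unfreeze_parameters_by_name(model: torch.nn.Module, target_modules: tuple[str, ...]) -> tuple[int, int]:
    """Illustrative stand-in: unfreeze parameters whose name matches any target substring."""
    num_trainable, num_skipped = 0, 0
    for name, param in model.named_parameters():
        if any(target in name for target in target_modules):
            # assumption: only floating-point parameters can require gradients (e.g. skip quantized weights)
            if param.is_floating_point():
                param.requires_grad = True
                num_trainable += int(param.numel())
            else:
                num_skipped += int(param.numel())
    return num_trainable, num_skipped
```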
99 changes: 99 additions & 0 deletions src/pruna/algorithms/global_utils/recovery/adapters/head.py
@@ -0,0 +1,99 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import torch

from pruna.algorithms.global_utils.recovery.adapters import PrunaAdapter, utils
from pruna.logging.logger import pruna_logger


class HeadAdapter(PrunaAdapter):
"""Adapter for finetuning the model's head while keeping the backbone as is."""

adapter_prefix = "head"

@classmethod
def get_hyperparameters(cls, *args, **kwargs) -> list:
"""
Configure all method-specific hyperparameters with ConfigSpace.

Parameters
----------
*args : Any
Unused arguments.
**kwargs : Any
Unused keyword arguments.

Returns
-------
list
The hyperparameters.
"""
return []

@classmethod
def activate(cls, model: torch.nn.Module, *args, **kwargs) -> tuple[torch.nn.Module, int, int]:
"""
Activate the model's head for training.

Parameters
----------
model : torch.nn.Module
The model containing the head.
*args : Any
Unused additional arguments.
**kwargs : Any
Unused additional keyword arguments.

Returns
-------
torch.nn.Module
The model with the head activated.
int
The number of trainable head parameters.
int
The number of skipped head parameters.
"""
# find head from type and name
model_heads = [
component
for comp_name, component in inspect.getmembers(model)
if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower()
]
if len(model_heads) != 1:
# = 0: model with no head, e.g. diffusers
# > 1: model with multiple heads, e.g. for localization, not currently supported
model_head_names = [
comp_name
for comp_name, component in inspect.getmembers(model)
if isinstance(component, torch.nn.Linear) and "head" in comp_name.lower()
]
pruna_logger.warning(
f"Found multiple heads but expected only one: {model_head_names}. Skipping head finetuning."
)
return model, 0, 0
model_head = model_heads[0]

# unfreeze head parameters, recording the number of trainable and skipped parameters
num_activ_param, num_skip_param = 0, 0
for param in model_head.parameters():
if utils.is_trainable(param):
param.requires_grad = True
num_activ_param += int(param.numel())
else:
num_skip_param += int(param.numel())

return model, num_activ_param, num_skip_param
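
A small, self-contained illustration of the head-detection logic above (the toy model and attribute names are invented for this example): `inspect.getmembers` only sees top-level attributes, so exactly one `torch.nn.Linear` whose name contains "head" is expected.

```python
import inspect

import torch


class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
        self.lm_head = torch.nn.Linear(16, 4)  # the single linear "head" the adapter will unfreeze


model = ToyModel()
heads = [
    (name, module)
    for name, module in inspect.getmembers(model)
    if isinstance(module, torch.nn.Linear) and "head" in name.lower()
]
print([name for name, _ in heads])  # ['lm_head'] -- the Linear inside the Sequential is not a direct member
```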