From b2fca70c3d4a2cec29897db3fd604c9be4105189 Mon Sep 17 00:00:00 2001 From: fl1123 Date: Fri, 18 Jul 2025 15:17:28 +0100 Subject: [PATCH] Feat: Implement FW-merging with loss approximation. --- config/method/fw_merging/fw_hard_am.yaml | 16 + .../fw_merging/fw_hard_am_loss_approx.yaml | 16 + .../fw_merging/fw_hard_loss_approx.yaml | 11 + .../fw_merging/fw_soft_loss_approx.yaml | 12 + fusion_bench/method/__init__.py | 4 +- fusion_bench/method/fw_merging/__init__.py | 6 +- fusion_bench/method/fw_merging/fw_hard_am.py | 537 ++++++++++++++++ .../fw_merging/fw_hard_am_loss_approx.py | 603 ++++++++++++++++++ .../method/fw_merging/fw_hard_loss_approx.py | 476 ++++++++++++++ .../method/fw_merging/fw_soft_loss_approx.py | 574 +++++++++++++++++ .../wrappers/layer_wise_fusion_fw_am.py | 302 +++++++++ 11 files changed, 2554 insertions(+), 3 deletions(-) create mode 100644 config/method/fw_merging/fw_hard_am.yaml create mode 100644 config/method/fw_merging/fw_hard_am_loss_approx.yaml create mode 100644 config/method/fw_merging/fw_hard_loss_approx.yaml create mode 100644 config/method/fw_merging/fw_soft_loss_approx.yaml create mode 100644 fusion_bench/method/fw_merging/fw_hard_am.py create mode 100644 fusion_bench/method/fw_merging/fw_hard_am_loss_approx.py create mode 100644 fusion_bench/method/fw_merging/fw_hard_loss_approx.py create mode 100644 fusion_bench/method/fw_merging/fw_soft_loss_approx.py create mode 100644 fusion_bench/models/wrappers/layer_wise_fusion_fw_am.py diff --git a/config/method/fw_merging/fw_hard_am.yaml b/config/method/fw_merging/fw_hard_am.yaml new file mode 100644 index 00000000..0547439e --- /dev/null +++ b/config/method/fw_merging/fw_hard_am.yaml @@ -0,0 +1,16 @@ +_target_: fusion_bench.method.FrankWolfeHardAdamergingAlgorithm +merge_fn: task_arithmetic +max_iters: 10 +step_size: 0.1 +dataset_size: 100 +tasks: [] +init_weight: base +loss_fn: cross_entropy +scaling_factor: 0.3 +max_num_models: 100 +granularity: task +init_layer_weights: 0.0 +ada_merge: True +ada_max_steps: 1000 +ada_optimizer: adam +ada_lr: 1e-3 diff --git a/config/method/fw_merging/fw_hard_am_loss_approx.yaml b/config/method/fw_merging/fw_hard_am_loss_approx.yaml new file mode 100644 index 00000000..e5cd2568 --- /dev/null +++ b/config/method/fw_merging/fw_hard_am_loss_approx.yaml @@ -0,0 +1,16 @@ +_target_: fusion_bench.method.FrankWolfeHardAdamergingLossApproxAlgorithm +merge_fn: task_arithmetic +max_iters: 10 +step_size: 0.1 +dataset_size: 100 +tasks: [] +init_weight: base +loss_fn: cross_entropy +scaling_factor: 0.3 +max_num_models: 100 +granularity: task +init_layer_weights: 0.0 +ada_merge: True +ada_max_steps: 1000 +ada_optimizer: adam +ada_lr: 1e-3 diff --git a/config/method/fw_merging/fw_hard_loss_approx.yaml b/config/method/fw_merging/fw_hard_loss_approx.yaml new file mode 100644 index 00000000..4926175e --- /dev/null +++ b/config/method/fw_merging/fw_hard_loss_approx.yaml @@ -0,0 +1,11 @@ +_target_: fusion_bench.method.FrankWolfeHardLossApproxAlgorithm +merge_fn: task_arithmetic +max_iters: 10 +step_size: 0.1 +dataset_size: 100 +tasks: [] +init_weight: +loss_fn: cross_entropy +scaling_factor: 0.3 +max_num_models: 100 +granularity: task diff --git a/config/method/fw_merging/fw_soft_loss_approx.yaml b/config/method/fw_merging/fw_soft_loss_approx.yaml new file mode 100644 index 00000000..094edfcd --- /dev/null +++ b/config/method/fw_merging/fw_soft_loss_approx.yaml @@ -0,0 +1,12 @@ +_target_: fusion_bench.method.FrankWolfeSoftLossApproxAlgorithm +init_weight: +max_iters: 10 +merge_fn: 'adamerging' 
+tasks: +ada_iters: 500 +dataset_size: 100 +ada_coeff: 1e-8 +step_size: 0.1 +max_num_models: 100 +granularity: task +ada_loss: entropy_loss diff --git a/fusion_bench/method/__init__.py b/fusion_bench/method/__init__.py index 20573a38..c4647a18 100644 --- a/fusion_bench/method/__init__.py +++ b/fusion_bench/method/__init__.py @@ -100,7 +100,7 @@ "SparseLoForLlama", "PCPSparseLoForLlama", ], - "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"], + "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm", "FrankWolfeHardAdamergingAlgorithm", "FrankWolfeSoftLossApproxAlgorithm", "FrankWolfeHardLossApproxAlgorithm", "FrankWolfeHardAdamergingLossApproxAlgorithm"], } @@ -182,7 +182,7 @@ from .ties_merging import TiesMergingAlgorithm from .we_moe import CLIPWeightEnsemblingMoEAlgorithm from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama - from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm + from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm, FrankWolfeHardAdamergingAlgorithm, FrankWolfeSoftLossApproxAlgorithm, FrankWolfeHardLossApproxAlgorithm, FrankWolfeHardAdamergingLossApproxAlgorithm else: sys.modules[__name__] = LazyImporter( diff --git a/fusion_bench/method/fw_merging/__init__.py b/fusion_bench/method/fw_merging/__init__.py index 74188ef4..37fa1dd7 100644 --- a/fusion_bench/method/fw_merging/__init__.py +++ b/fusion_bench/method/fw_merging/__init__.py @@ -1,2 +1,6 @@ from .fw_hard import FrankWolfeHardAlgorithm -from .fw_soft import FrankWolfeSoftAlgorithm \ No newline at end of file +from .fw_soft import FrankWolfeSoftAlgorithm +from .fw_hard_am import FrankWolfeHardAdamergingAlgorithm +from .fw_soft_loss_approx import FrankWolfeSoftLossApproxAlgorithm +from .fw_hard_loss_approx import FrankWolfeHardLossApproxAlgorithm +from .fw_hard_am_loss_approx import FrankWolfeHardAdamergingLossApproxAlgorithm \ No newline at end of file diff --git a/fusion_bench/method/fw_merging/fw_hard_am.py b/fusion_bench/method/fw_merging/fw_hard_am.py new file mode 100644 index 00000000..2b23acc7 --- /dev/null +++ b/fusion_bench/method/fw_merging/fw_hard_am.py @@ -0,0 +1,537 @@ +""" +This script contains the general implementation of the Task Arithmetic method. 
+ +http://arxiv.org/abs/2212.04089 +""" + +import logging +import os +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, Dict +from copy import deepcopy +from collections import defaultdict +from functools import partial +import functools + +from .utils import * + +import torch +from lightning.fabric.utilities.rank_zero import rank_zero_only +from omegaconf import DictConfig +from torch import Tensor, nn +from torch.utils.data import DataLoader +from tqdm.autonotebook import tqdm + +from fusion_bench.compat.method import ModelFusionAlgorithm +from fusion_bench.dataset.clip_dataset import CLIPDataset +from fusion_bench.mixins import CLIPClassificationMixin +from fusion_bench.compat.modelpool import ModelPool, HuggingFaceClipVisionPool +from fusion_bench.mixins.lightning_fabric import LightningFabricMixin +from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin +from fusion_bench.models.wrappers.layer_wise_fusion_fw_am import ( + LayerWiseMergedModel, + get_layer_wise_weights, +) +from fusion_bench.utils.data import load_tensor_from_file +from fusion_bench.utils.type import TorchModelType + +if TYPE_CHECKING: + from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram + +from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin +from fusion_bench.modelpool import BaseModelPool +from fusion_bench.utils.data import InfiniteDataLoader +from fusion_bench.utils.state_dict_arithmetic import ( + state_dict_add, + state_dict_mul, + state_dict_sub, +) +from fusion_bench.utils.type import StateDictType +from fusion_bench.utils import instantiate + +log = logging.getLogger(__name__) + + +# @torch.no_grad() +# def task_arithmetic_merge( +# pretrained_model: nn.Module, +# finetuned_models: List[Dict[str, Tensor]], +# scaling_factor: float, +# inplace: bool = True, +# ) -> nn.Module: +# """ +# Merges the task vectors from multiple fine-tuned models into a single pre-trained model. + +# Args: +# pretrained_model (nn.Module): The pre-trained model to which the task vectors will be added. +# finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated. +# scaling_factor (float): A factor by which the task vectors will be scaled before merging. +# inplace (bool, optional): If True, the pre-trained model will be modified in place. +# If False, a copy of the pre-trained model will be modified. Defaults to True. + +# Returns: +# nn.Module: The pre-trained model with the merged task vectors. 
+# """ +# if not inplace: +# pretrained_model = deepcopy(pretrained_model) +# if isinstance(finetuned_models[0], nn.Module): +# finetuned_models = [deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models] +# task_vector: StateDictType = None +# # Calculate the total task vector +# for model in finetuned_models: +# if task_vector is None: +# task_vector = state_dict_sub( +# model, +# pretrained_model.state_dict(keep_vars=True), +# ) +# else: +# task_vector = state_dict_add( +# task_vector, +# state_dict_sub( +# model, +# pretrained_model.state_dict(keep_vars=True), +# ), +# ) +# # scale the task vector +# task_vector = state_dict_mul(task_vector, scaling_factor) +# # add the task vector to the pretrained model +# state_dict = state_dict_add( +# pretrained_model.state_dict(keep_vars=True), task_vector +# ) +# pretrained_model.load_state_dict(state_dict) +# return pretrained_model + + +# @torch.no_grad() +# def ties_merge( +# pretrained_model: nn.Module, +# finetuned_models: List[Dict[str, Tensor]], +# scaling_factor: float, +# threshold: float, +# ) -> nn.Module: +# remove_keys = [] +# merge_func = "sum" +# if isinstance(finetuned_models[0], nn.Module): +# finetuned_models = [deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models] + +# ptm_check = pretrained_model.state_dict(keep_vars=True) + +# # Compute the task vectors +# flat_ft = torch.vstack( +# [state_dict_to_vector(check, remove_keys) for check in finetuned_models] +# ) +# flat_ptm = state_dict_to_vector(ptm_check, remove_keys) +# tv_flat_checks = flat_ft - flat_ptm + +# # Perform TIES Merging +# merged_tv = ties_merging( +# tv_flat_checks, +# reset_thresh=threshold, +# merge_func=merge_func, +# ) +# merged_check = flat_ptm + scaling_factor * merged_tv +# merged_state_dict = vector_to_state_dict( +# merged_check, ptm_check, remove_keys=remove_keys +# ) + +# # Load the merged state dict into the pretrained model +# pretrained_model.load_state_dict(merged_state_dict) +# return pretrained_model + +@torch.no_grad() +def task_arithmetic_merge( + merged_model: LayerWiseMergedModel, + finetuned_models: List[Dict[str, Tensor]], + indices: List[int] = None, + scaling_factor: float = 1.0, +) -> nn.Module: + """ + Merges the task vectors from multiple fine-tuned models into a single pre-trained model. + + Args: + pretrained_model (nn.Module): The pre-trained model to which the task vectors will be added. + finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated. + scaling_factor (float): A factor by which the task vectors will be scaled before merging. + inplace (bool, optional): If True, the pre-trained model will be modified in place. + If False, a copy of the pre-trained model will be modified. Defaults to True. + + Returns: + nn.Module: The pre-trained model with the merged task vectors. + """ + print("shapes: ", merged_model.merge_weight.shape, len(indices)) + print(indices) + # Directly edit merge_weight of the merged model + for l, model_index in enumerate(indices): + merged_model.merge_weight[model_index, l] += scaling_factor + merged_model.merge_weights() + return merged_model + +def entropy_loss(logits: Tensor, pred = None, eps: float = 1e-8) -> Tensor: + """ + Compute the entropy loss of a set of logits. + + Args: + logits (Tensor): The logits to compute the entropy loss of. + eps (float): A small value to avoid log(0). Default is 1e-8. + + Returns: + Tensor: The entropy loss of the logits. 
+ """ + # Ensure the logits tensor has 2 dimensions + assert ( + logits.dim() == 2 + ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}" + + # Compute the softmax probabilities + probs = torch.softmax(logits, dim=-1) + + # Compute the entropy loss + return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean() + + +class FrankWolfeHardAdamergingAlgorithm( + CLIPClassificationMixin, + ModelFusionAlgorithm, + SimpleProfilerMixin, +): + + + def __init__(self, + merge_fn: str, + step_size: float, + max_iters: int, + dataset_size:int, + tasks: List[str] = [], + granularity: str = 'task', + max_num_models: int = 100, + loss_fn: str = "cross_entropy", + init_weight: str = "", + scaling_factor: float = 1., + threshold: int = 20, + init_layer_weights: float = None, + ada_merge: bool = True, + ada_max_steps: int = 500, + ada_optimizer: str = "adam", + ada_lr: float = 1e-3, + **kwargs): + """ + Initializes the TaskArithmeticAlgorithm with the given scaling factor. + + Args: + scaling_factor (int): The factor by which the task vectors will be scaled before merging. + """ + # self.merger = merge_fn + # if merge_fn == "task_arithmetic": + # self.merge_fn = task_arithmetic_merge + # elif merge_fn == "ties": + # self.merge_fn = partial(ties_merge, threshold=threshold) + # # elif merge_fn == "concrete_ta": + # # self.merge_fn = ConcreteTaskArithmeticAlgorithmForCLIP( + # # instantiate(OmegaConf.load("config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml")) + # # ) + # else: + # raise ValueError(f"Unsupported merge_fn: {merge_fn}") + self.scaling_factor = scaling_factor + + self.init_weight = init_weight + self.step_size = step_size + self.max_iters = max_iters + self.granularity = granularity + self.loss_fn = loss_fn + self.tasks = tasks + self.dataset_size = dataset_size + self.max_num_models = max_num_models + + self.ada_merge = ada_merge + self.init_layer_weights = init_layer_weights + self.ada_max_steps = ada_max_steps + self.ada_optimizer = ada_optimizer + self.ada_lr = ada_lr + super().__init__(**kwargs) + + + def on_frank_wolfe_iteration_start(self): + self.setup_zero_shot_classification_head() + + @functools.cache + def get_shuffled_loader_iter(self, task: str): + if self.loss_fn == "cross_entropy": + # get dataloader kwargs + dataloader_kwargs = self._dataloader_kwargs.copy() + dataloader_kwargs["shuffle"] = True + dataloader_kwargs["batch_size"] = 1 + + # get the test dataset + clip_dataset = CLIPDataset( + self.modelpool.load_train_dataset(task), self.clip_processor + ) + # create the dataloader + loader = DataLoader(clip_dataset, **dataloader_kwargs) + loader = self.fabric.setup_dataloaders(loader) + return iter(InfiniteDataLoader(loader)) + elif self.loss_fn == "entropy": + return super().get_shuffled_test_loader_iter( + task, + batch_size=1, + ) + else: + raise ValueError(f"Unsupported loss function: {self.loss_fn}") + + + def frank_wolfe_iteration(self, merged_model): + + merged_model.train() + # zero the gradients + for name, param in merged_model.named_parameters(): + param.requires_grad = True + param.grad = None + + if self.loss_fn == "cross_entropy": + loss_fn = nn.CrossEntropyLoss() + elif self.loss_fn == "entropy": + loss_fn = entropy_loss + avg_loss = defaultdict(list) + tasks = self.tasks if self.tasks else self.modelpool.model_names + for task in tasks: + log.info(f"Processing task {task}") + for _ in range(self.dataset_size): + with self.profile("data loading"): + batch = next(self.get_shuffled_loader_iter(task)) + with 
self.profile("forward pass"): + logits = self.compute_logits(merged_model, batch[0], task) + loss = loss_fn(logits, batch[1]) / (self.dataset_size * len(self.modelpool.model_names)) + with self.profile("backward pass"): + # self.fabric.backward(loss, retain_graph=True) + loss.backward() + avg_loss[task].append(loss.item()) + + # calculate the loss + avg_loss = {task: sum(losses) / len(losses) for task, losses in avg_loss.items()} + log.info(f"Average Loss: {avg_loss}, Total Loss: {sum(avg_loss.values()) / len(avg_loss)}") + + gradients = {name: param.grad.clone().to('cuda') for name, param in merged_model.named_parameters() if param.requires_grad} + for name, param in merged_model.named_parameters(): + param.grad = None + merged_model.eval() + + return gradients + + + def frank_wolfe_selection(self, gradients, checkpoints, model_to_merge_names={}, type='task'): + assert type in ['task', 'layer'], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']" + min_inner_product = float("inf") + min_model = None + min_model_name = None + log_dict = {} + if type == 'task': + for model_name, model_to_merge in checkpoints.items(): + model_to_merge = model_to_merge.to('cuda').state_dict() + inner_product_sum = 0 + for param_name, param_value in model_to_merge.items(): + # caclulate consine similarity + grad = gradients[param_name] + ckpt = model_to_merge[param_name] + param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (torch.norm(grad) * torch.norm(ckpt)) + inner_product_sum += param_alignment + log_dict[model_name] = inner_product_sum.item() + if inner_product_sum < min_inner_product and model_name not in model_to_merge_names: + min_inner_product = inner_product_sum + min_model = deepcopy(model_to_merge) + min_model_name = model_name + else: + min_model = {} + min_inner_product = {} + min_idx = {} + min_model_name = {} + for model_name, model_to_merge in checkpoints.items(): + model_to_merge = model_to_merge.to('cuda').state_dict() + for param_name, param_value in model_to_merge.items(): + # caclulate consine similarity + grad = gradients[param_name] + ckpt = model_to_merge[param_name] + param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (torch.norm(grad) * torch.norm(ckpt)) + if (param_name not in min_inner_product or param_alignment < min_inner_product[param_name]) and \ + model_name not in model_to_merge_names[param_name]: + min_inner_product[param_name] = param_alignment + # if min_inner_product[param_name] < 0: + min_model[param_name] = param_value + min_idx[param_name] = model_name + min_model_name[param_name] = model_name + # else: + # min_model[param_name] = torch.zeros_like(param_value) + min_inner_product = sum(min_inner_product.values()) + log_dict = {model_name: 0 for model_name in checkpoints.keys()} + for k in min_idx.values(): + log_dict[k] += 1 + + return min_model, min_model_name, min_inner_product, log_dict + + + def run_adamerging(self, module: "LayerWiseMergedModel[TorchModelType]"): + # configure optimizer + if self.ada_optimizer == "adam": + optimizer = torch.optim.Adam([module.merge_weight], lr=self.ada_lr) + print(f"{optimizer=}") + module, optimizer = self.fabric.setup(module, optimizer) + else: + raise ValueError(f"Unsupported optimizer: {self.ada_optimizer}") + + module.train() + module.merge_weights() + for step_idx in ( + pbar := tqdm( + range(self.ada_max_steps), + "AdaMerging Test-time adaptation", + dynamic_ncols=True, + ) + ): + # default behavior for first-order optimizers + for task in self.tasks: + with 
self.profile("data loading"): + batch = next(self.get_shuffled_test_loader_iter(task)) + with self.profile("forward pass"): + logits = self.compute_logits(module, batch[0], task) + loss = entropy_loss(logits) + with self.profile("backward pass"): + self.fabric.backward(loss, retain_graph=True) + + with self.profile("optimizer step"): + optimizer.step() + optimizer.zero_grad() + with self.profile("merging weights"): + module.merge_weights() + + metrics = { + "train/loss": loss.item(), + "train/weight_max": module.merge_weight.max().item(), + "train/weight_min": module.merge_weight.min().item(), + "train/weight_mean": module.merge_weight.mean().item(), + } + self.fabric.log_dict(metrics, step=step_idx) + pbar.set_postfix(metrics) + + self.print_profile_summary() + return module + + + def run(self, modelpool: HuggingFaceClipVisionPool): + log.info("Fusing models using FW merging.") + self.modelpool = modelpool + # self.log_hyperparams(self.config) + self.on_frank_wolfe_iteration_start() + + assert modelpool.has_pretrained, "Pretrained model is required." + finetuned_models = {name: modelpool.load_model(name) for name in modelpool.model_names[:self.max_num_models]} + pretrained_model = modelpool.load_model("_pretrained_") + + if self.init_weight: + if self.init_weight == 'base': + log.info("Initializing the merged model with the base model") + merged_model = pretrained_model + else: + log.info("Initializing the merged model with the initial weight") + if isinstance(self.init_weight, str): + # self.config.weights is a path to a saved tensor + layer_wise_weight = load_tensor_from_file(self.init_weight) + else: + raise ValueError(f"Unsupported weights format: {self.init_weight}") + + merged_model = LayerWiseMergedModel( + layer_wise_weight=layer_wise_weight, + pretrained_model=modelpool.load_model("_pretrained_"), + finetuned_models=list(finetuned_models.values()), + clamp_weights=False, + tie_weights=True, + strict=False, + ) + merged_model = merged_model.merge_and_unload() + + # Construct the layer-wise merged model instead + initial_model = modelpool.load_model("_pretrained_") + initial_model.load_state_dict(merged_model.state_dict()) + finetuned_models['initial'] = initial_model + # Create layer-wise weights + layer_wise_weight = get_layer_wise_weights( + num_models=len(finetuned_models), + num_layers=len( + tuple( + filter(lambda p: p.requires_grad, pretrained_model.parameters()) + ) + ), + init_values= self.init_layer_weights if self.init_layer_weights is not None else 0.0, + ) + # Change the last row (initial model) to 1.0 + layer_wise_weight[-1, :] = 1.0 + # Create the merged model with the layer-wise weights + with torch.no_grad(): + self.set_requires_grad(merged_model, initial_model) + merged_model = LayerWiseMergedModel( + layer_wise_weight=layer_wise_weight, + pretrained_model=merged_model, + finetuned_models=list(finetuned_models.values()), + clamp_weights=False, + tie_weights=True, + strict=False, + ).cuda() + else: + raise ValueError("`init_weight` must be base or specified in the config file.") + + + # FW iteration + for step_idx in ( + pbar := tqdm( + range(self.max_iters if not self.is_debug_mode else 1), + ("[DEBUG MODE] " if self.is_debug_mode else "") + + "Frank-Wolfe Merging", + dynamic_ncols=True, + ) + ): + torch.cuda.empty_cache() + merged_model_tmp = merged_model.merge_and_copy().cuda() + with torch.no_grad(): + self.set_requires_grad(merged_model_tmp, initial_model) + torch.set_grad_enabled(True) + gradients = self.frank_wolfe_iteration(merged_model_tmp) + 
torch.set_grad_enabled(False) + grad_norm = torch.norm(torch.stack([torch.norm(g) for g in gradients.values()])) + + model_to_merge_names = [] if self.granularity == 'task' else {name: [] for name in merged_model.state_dict().keys()} + min_model, min_model_name, min_alignment, chosen_model = self.frank_wolfe_selection(gradients, finetuned_models, model_to_merge_names=model_to_merge_names, type=self.granularity) + + # Determine step size + step = 2 / (step_idx + 2) * self.step_size + + # print iteration information + log.info(f"Iteration {step_idx+1}, Task Vector: {min_model_name}, Gradient Norm: {grad_norm:.6f}, Inner Products: {min_alignment:.6f}, Chosen Model: {chosen_model}") + + # Calculate position of each min_model in the finetuned model lists in each layer + if self.granularity == 'task': + indices = [list(finetuned_models.keys()).index(min_model_name) for _ in range(len(merged_model.state_dict()))] + else: + indices = [list(finetuned_models.keys()).index(min_model_name[param_name]) for param_name in merged_model.state_dict().keys()] + merged_model = task_arithmetic_merge( + merged_model=merged_model, + finetuned_models=[min_model], + indices=indices, + scaling_factor=step * self.scaling_factor, + ) + + # Merge model with Adamerging + if self.ada_merge: + print("number of models to merge: ", len(modelpool.model_names)) + torch.set_grad_enabled(True) + merged_model = self.run_adamerging(merged_model) + torch.set_grad_enabled(False) + + with torch.no_grad(): + merged_model = merged_model.merge_and_unload() + self.set_requires_grad(merged_model, initial_model) + # eval and return model + merged_model = merged_model.cuda().eval() + return merged_model + + def set_requires_grad(self, merged_model, initial_model): + for name, param in initial_model.named_parameters(): + for n, p in merged_model.named_parameters(): + if name == n: + p.requires_grad = param.requires_grad diff --git a/fusion_bench/method/fw_merging/fw_hard_am_loss_approx.py b/fusion_bench/method/fw_merging/fw_hard_am_loss_approx.py new file mode 100644 index 00000000..8d5fd7cb --- /dev/null +++ b/fusion_bench/method/fw_merging/fw_hard_am_loss_approx.py @@ -0,0 +1,603 @@ +""" +This script contains the general implementation of the Task Arithmetic method. 
+ +http://arxiv.org/abs/2212.04089 +""" + +import logging +import os +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, Dict +from copy import deepcopy +from collections import defaultdict +from functools import partial +import functools + +from .utils import * + +import torch +from lightning.fabric.utilities.rank_zero import rank_zero_only +from omegaconf import DictConfig +from torch import Tensor, nn +from torch.utils.data import DataLoader +from tqdm.autonotebook import tqdm + +from fusion_bench.compat.method import ModelFusionAlgorithm +from fusion_bench.dataset.clip_dataset import CLIPDataset +from fusion_bench.mixins import CLIPClassificationMixin +from fusion_bench.compat.modelpool import ModelPool, HuggingFaceClipVisionPool +from fusion_bench.mixins.lightning_fabric import LightningFabricMixin +from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin +from fusion_bench.models.wrappers.layer_wise_fusion_fw_am import ( + LayerWiseMergedModel, + get_layer_wise_weights, +) +from fusion_bench.utils.data import load_tensor_from_file +from fusion_bench.utils.type import TorchModelType + +if TYPE_CHECKING: + from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram + +from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin +from fusion_bench.modelpool import BaseModelPool +from fusion_bench.utils.data import InfiniteDataLoader +from fusion_bench.utils.state_dict_arithmetic import ( + state_dict_add, + state_dict_mul, + state_dict_sub, +) +from fusion_bench.utils.type import StateDictType +from fusion_bench.utils import instantiate + +log = logging.getLogger(__name__) + + +# @torch.no_grad() +# def task_arithmetic_merge( +# pretrained_model: nn.Module, +# finetuned_models: List[Dict[str, Tensor]], +# scaling_factor: float, +# inplace: bool = True, +# ) -> nn.Module: +# """ +# Merges the task vectors from multiple fine-tuned models into a single pre-trained model. + +# Args: +# pretrained_model (nn.Module): The pre-trained model to which the task vectors will be added. +# finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated. +# scaling_factor (float): A factor by which the task vectors will be scaled before merging. +# inplace (bool, optional): If True, the pre-trained model will be modified in place. +# If False, a copy of the pre-trained model will be modified. Defaults to True. + +# Returns: +# nn.Module: The pre-trained model with the merged task vectors. 
+# """ +# if not inplace: +# pretrained_model = deepcopy(pretrained_model) +# if isinstance(finetuned_models[0], nn.Module): +# finetuned_models = [deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models] +# task_vector: StateDictType = None +# # Calculate the total task vector +# for model in finetuned_models: +# if task_vector is None: +# task_vector = state_dict_sub( +# model, +# pretrained_model.state_dict(keep_vars=True), +# ) +# else: +# task_vector = state_dict_add( +# task_vector, +# state_dict_sub( +# model, +# pretrained_model.state_dict(keep_vars=True), +# ), +# ) +# # scale the task vector +# task_vector = state_dict_mul(task_vector, scaling_factor) +# # add the task vector to the pretrained model +# state_dict = state_dict_add( +# pretrained_model.state_dict(keep_vars=True), task_vector +# ) +# pretrained_model.load_state_dict(state_dict) +# return pretrained_model + + +# @torch.no_grad() +# def ties_merge( +# pretrained_model: nn.Module, +# finetuned_models: List[Dict[str, Tensor]], +# scaling_factor: float, +# threshold: float, +# ) -> nn.Module: +# remove_keys = [] +# merge_func = "sum" +# if isinstance(finetuned_models[0], nn.Module): +# finetuned_models = [deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models] + +# ptm_check = pretrained_model.state_dict(keep_vars=True) + +# # Compute the task vectors +# flat_ft = torch.vstack( +# [state_dict_to_vector(check, remove_keys) for check in finetuned_models] +# ) +# flat_ptm = state_dict_to_vector(ptm_check, remove_keys) +# tv_flat_checks = flat_ft - flat_ptm + +# # Perform TIES Merging +# merged_tv = ties_merging( +# tv_flat_checks, +# reset_thresh=threshold, +# merge_func=merge_func, +# ) +# merged_check = flat_ptm + scaling_factor * merged_tv +# merged_state_dict = vector_to_state_dict( +# merged_check, ptm_check, remove_keys=remove_keys +# ) + +# # Load the merged state dict into the pretrained model +# pretrained_model.load_state_dict(merged_state_dict) +# return pretrained_model + +@torch.no_grad() +def task_arithmetic_merge( + merged_model: LayerWiseMergedModel, + finetuned_models: List[Dict[str, Tensor]], + indices: List[int] = None, + scaling_factor: float = 1.0, +) -> nn.Module: + """ + Merges the task vectors from multiple fine-tuned models into a single pre-trained model. + + Args: + pretrained_model (nn.Module): The pre-trained model to which the task vectors will be added. + finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated. + scaling_factor (float): A factor by which the task vectors will be scaled before merging. + inplace (bool, optional): If True, the pre-trained model will be modified in place. + If False, a copy of the pre-trained model will be modified. Defaults to True. + + Returns: + nn.Module: The pre-trained model with the merged task vectors. + """ + print("shapes: ", merged_model.merge_weight.shape, len(indices)) + print(indices) + # Directly edit merge_weight of the merged model + for l, model_index in enumerate(indices): + merged_model.merge_weight[model_index, l] += scaling_factor + merged_model.merge_weights() + return merged_model + +def entropy_loss(logits: Tensor, pred = None, eps: float = 1e-8) -> Tensor: + """ + Compute the entropy loss of a set of logits. + + Args: + logits (Tensor): The logits to compute the entropy loss of. + eps (float): A small value to avoid log(0). Default is 1e-8. + + Returns: + Tensor: The entropy loss of the logits. 
+ """ + # Ensure the logits tensor has 2 dimensions + assert ( + logits.dim() == 2 + ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}" + + # Compute the softmax probabilities + probs = torch.softmax(logits, dim=-1) + + # Compute the entropy loss + return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean() + + +class FrankWolfeHardAdamergingLossApproxAlgorithm( + CLIPClassificationMixin, + ModelFusionAlgorithm, + SimpleProfilerMixin, +): + + + def __init__(self, + merge_fn: str, + step_size: float, + max_iters: int, + dataset_size:int, + tasks: List[str] = [], + granularity: str = 'task', + max_num_models: int = 100, + loss_fn: str = "cross_entropy", + init_weight: str = "", + scaling_factor: float = 1., + threshold: int = 20, + init_layer_weights: float = None, + ada_merge: bool = True, + ada_max_steps: int = 500, + ada_optimizer: str = "adam", + ada_lr: float = 1e-3, + **kwargs): + """ + Initializes the TaskArithmeticAlgorithm with the given scaling factor. + + Args: + scaling_factor (int): The factor by which the task vectors will be scaled before merging. + """ + # self.merger = merge_fn + # if merge_fn == "task_arithmetic": + # self.merge_fn = task_arithmetic_merge + # elif merge_fn == "ties": + # self.merge_fn = partial(ties_merge, threshold=threshold) + # # elif merge_fn == "concrete_ta": + # # self.merge_fn = ConcreteTaskArithmeticAlgorithmForCLIP( + # # instantiate(OmegaConf.load("config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml")) + # # ) + # else: + # raise ValueError(f"Unsupported merge_fn: {merge_fn}") + self.scaling_factor = scaling_factor + + self.init_weight = init_weight + self.step_size = step_size + self.max_iters = max_iters + self.granularity = granularity + self.loss_fn = loss_fn + self.tasks = tasks + self.dataset_size = dataset_size + self.max_num_models = max_num_models + + self.ada_merge = ada_merge + self.init_layer_weights = init_layer_weights + self.ada_max_steps = ada_max_steps + self.ada_optimizer = ada_optimizer + self.ada_lr = ada_lr + super().__init__(**kwargs) + + + def on_frank_wolfe_iteration_start(self): + self.setup_zero_shot_classification_head() + + def calculate_projection(self, pretrained_model: nn.Module, finetuned_models: List[nn.Module]): + # Compute the svd and projection here + pretrained_sd = pretrained_model.state_dict(keep_vars=True) + filtered_keys = [ + k + for k in pretrained_sd.keys() + if ("encoder" in k and "layer_norm" not in k and "weight" in k) + ] + task_vectors = [] + for m in finetuned_models: + m.requires_grad_(False) + pretrained_model = pretrained_model.requires_grad_(False) + for model in finetuned_models: + model_sd = model.state_dict(keep_vars=True) + filtered_task_vector = { + k: (model_sd[k].to("cpu") - pretrained_sd[k].to("cpu")) for k in filtered_keys + } + task_vectors.append(filtered_task_vector) + + projection = {} + for layer_name in task_vectors[0].keys(): + for i, vector in enumerate(task_vectors): + layer_vector = vector[layer_name] + u, s, v = torch.linalg.svd(layer_vector, full_matrices=False) + if i == 0: + print(f"Computed SVD for {layer_name}...") + sum_u = torch.zeros_like(u, device=layer_vector.device) + sum_s = torch.zeros_like(s, device=layer_vector.device) + sum_v = torch.zeros_like(v, device=layer_vector.device) + + reduced_index_s = int(s.shape[0] / len(task_vectors)) + + # select only the first reduced_index_s columns of u and place them + sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[ + :, :reduced_index_s + ] + sum_s[i * 
reduced_index_s : (i + 1) * reduced_index_s] = s[ + :reduced_index_s + ] + # select only the first reduced_index_s rows of v and place them + sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[ + :reduced_index_s, : + ] + # SVD of shared subspace to avoid overlapping task vectors + u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False) + # u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False) + layer_proj = torch.matmul( + u_u[:, : int(s.shape[0] / len(task_vectors))], + u_u[:, : int(s.shape[0] / len(task_vectors))].T, + ) + projection[layer_name] = layer_proj.to("cpu") # Projection matrix for each layer + + for m in finetuned_models: + m.requires_grad_(True) + pretrained_model = pretrained_model.requires_grad_(True) + return projection, task_vectors + + @functools.cache + def get_shuffled_loader_iter(self, task: str): + if self.loss_fn == "cross_entropy": + # get dataloader kwargs + dataloader_kwargs = self._dataloader_kwargs.copy() + dataloader_kwargs["shuffle"] = True + dataloader_kwargs["batch_size"] = 1 + + # get the test dataset + clip_dataset = CLIPDataset( + self.modelpool.load_train_dataset(task), self.clip_processor + ) + # create the dataloader + loader = DataLoader(clip_dataset, **dataloader_kwargs) + loader = self.fabric.setup_dataloaders(loader) + return iter(InfiniteDataLoader(loader)) + elif self.loss_fn == "entropy": + return super().get_shuffled_test_loader_iter( + task, + batch_size=1, + ) + else: + raise ValueError(f"Unsupported loss function: {self.loss_fn}") + + + def frank_wolfe_iteration(self, merged_model): + + merged_model.train() + # zero the gradients + for name, param in merged_model.named_parameters(): + param.requires_grad = True + param.grad = None + sd = merged_model.state_dict(keep_vars=True) + + losses = defaultdict(list) + gradients = {} + + for layer_name in self.task_vectors[0].keys(): + task_layer_vectors = torch.stack([vec[layer_name] for vec in self.task_vectors]) + merged_model_layer_vector = sd[layer_name].to("cpu") + initial_model_layer_vector = self.initial_model.state_dict(keep_vars=True)[layer_name].to("cpu") + losses[layer_name] = 0.0 + for task_layer_vector in task_layer_vectors: + # -layer_vector + part_1 = -task_layer_vector + # merged_model - layer_vector + part_2 = merged_model_layer_vector - initial_model_layer_vector - task_layer_vector + # dot product between part_1 and part_2 + inner_product = torch.sum(part_1 * part_2) + result = inner_product * inner_product + losses[layer_name] += result + + print(f"Layer: {layer_name}, DoGE Loss: {losses[layer_name].item()}") + # calculate the gradients + losses[layer_name].backward() + g = sd[layer_name].grad.clone().to("cpu") + g = g - self.projection[layer_name].to("cpu") @ g + gradients[layer_name] = g + sd[layer_name].grad = None + + + # calculate the loss + avg_loss = sum(losses.values()) / len(self.task_vectors) + log.info(f"Average Loss: {avg_loss}, Total Loss: {sum(losses.values())}") + + for name, param in merged_model.named_parameters(): + param.grad = None + merged_model.eval() + + return gradients + + + def frank_wolfe_selection(self, gradients, checkpoints, model_to_merge_names={}, type='task'): + assert type in ['task', 'layer'], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']" + min_inner_product = float("inf") + min_model = None + min_model_name = None + log_dict = {} + if type == 'task': + for model_name, model_to_merge in checkpoints.items(): + model_to_merge = model_to_merge.to('cuda').state_dict() + inner_product_sum = 0 
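+                # Score each candidate checkpoint by summing, over its parameters, the cosine
+                # similarity between the (subspace-projected) loss-approximation gradient and
+                # the checkpoint weights; the checkpoint with the lowest total alignment is
+                # selected as this iteration's Frank-Wolfe direction.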
+ for param_name, param_value in model_to_merge.items(): + # caclulate consine similarity + grad = (gradients[param_name] if param_name in gradients else torch.zeros_like(param_value)).to("cpu") + ckpt = model_to_merge[param_name].to("cpu") + param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (torch.norm(grad) * torch.norm(ckpt)) + inner_product_sum += param_alignment + log_dict[model_name] = inner_product_sum.item() + if inner_product_sum < min_inner_product and model_name not in model_to_merge_names: + min_inner_product = inner_product_sum + min_model = deepcopy(model_to_merge) + min_model_name = model_name + else: + min_model = {} + min_inner_product = {} + min_idx = {} + min_model_name = {} + for model_name, model_to_merge in checkpoints.items(): + model_to_merge = model_to_merge.to('cuda').state_dict() + for param_name, param_value in model_to_merge.items(): + # caclulate consine similarity + grad = (gradients[param_name] if param_name in gradients else torch.zeros_like(param_value)).to("cpu") + ckpt = model_to_merge[param_name].to("cpu") + param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (torch.norm(grad) * torch.norm(ckpt)) + if (param_name not in min_inner_product or param_alignment < min_inner_product[param_name]) and \ + model_name not in model_to_merge_names[param_name]: + min_inner_product[param_name] = param_alignment + # if min_inner_product[param_name] < 0: + min_model[param_name] = param_value + min_idx[param_name] = model_name + min_model_name[param_name] = model_name + # else: + # min_model[param_name] = torch.zeros_like(param_value) + min_inner_product = sum(min_inner_product.values()) + log_dict = {model_name: 0 for model_name in checkpoints.keys()} + for k in min_idx.values(): + log_dict[k] += 1 + + return min_model, min_model_name, min_inner_product, log_dict + + + def run_adamerging(self, module: "LayerWiseMergedModel[TorchModelType]"): + # configure optimizer + if self.ada_optimizer == "adam": + optimizer = torch.optim.Adam([module.merge_weight], lr=self.ada_lr) + print(f"{optimizer=}") + module, optimizer = self.fabric.setup(module, optimizer) + else: + raise ValueError(f"Unsupported optimizer: {self.ada_optimizer}") + + module.train() + module.merge_weights() + for step_idx in ( + pbar := tqdm( + range(self.ada_max_steps), + "AdaMerging Test-time adaptation", + dynamic_ncols=True, + ) + ): + # default behavior for first-order optimizers + for task in self.tasks: + with self.profile("data loading"): + batch = next(self.get_shuffled_test_loader_iter(task)) + with self.profile("forward pass"): + logits = self.compute_logits(module, batch[0], task) + loss = entropy_loss(logits) + with self.profile("backward pass"): + self.fabric.backward(loss, retain_graph=True) + + with self.profile("optimizer step"): + optimizer.step() + optimizer.zero_grad() + with self.profile("merging weights"): + module.merge_weights() + + metrics = { + "train/loss": loss.item(), + "train/weight_max": module.merge_weight.max().item(), + "train/weight_min": module.merge_weight.min().item(), + "train/weight_mean": module.merge_weight.mean().item(), + } + self.fabric.log_dict(metrics, step=step_idx) + pbar.set_postfix(metrics) + + self.print_profile_summary() + return module + + + def run(self, modelpool: HuggingFaceClipVisionPool): + log.info("Fusing models using FW merging.") + self.modelpool = modelpool + # self.log_hyperparams(self.config) + self.on_frank_wolfe_iteration_start() + + assert modelpool.has_pretrained, "Pretrained model is required." 
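+        # Load at most `max_num_models` task-specific checkpoints from the model pool;
+        # the pretrained (base) checkpoint is loaded separately under the "_pretrained_" key.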
+ finetuned_models = {name: modelpool.load_model(name) for name in modelpool.model_names[:self.max_num_models]} + pretrained_model = modelpool.load_model("_pretrained_") + + if self.init_weight: + if self.init_weight == 'base': + log.info("Initializing the merged model with the base model") + merged_model = pretrained_model + else: + log.info("Initializing the merged model with the initial weight") + if isinstance(self.init_weight, str): + # self.config.weights is a path to a saved tensor + layer_wise_weight = load_tensor_from_file(self.init_weight) + else: + raise ValueError(f"Unsupported weights format: {self.init_weight}") + + merged_model = LayerWiseMergedModel( + layer_wise_weight=layer_wise_weight, + pretrained_model=modelpool.load_model("_pretrained_"), + finetuned_models=list(finetuned_models.values()), + clamp_weights=False, + tie_weights=True, + strict=False, + ) + merged_model = merged_model.merge_and_unload() + + # Construct the layer-wise merged model instead + self.initial_model = modelpool.load_model("_pretrained_") + self.initial_model.load_state_dict(merged_model.state_dict()) + finetuned_models['initial'] = self.initial_model + # Create layer-wise weights + layer_wise_weight = get_layer_wise_weights( + num_models=len(finetuned_models), + num_layers=len( + tuple( + filter(lambda p: p.requires_grad, pretrained_model.parameters()) + ) + ), + init_values= self.init_layer_weights if self.init_layer_weights is not None else 0.0, + ) + # Change the last row (initial model) to 1.0 + layer_wise_weight[-1, :] = 1.0 + # Create the merged model with the layer-wise weights + with torch.no_grad(): + self.set_requires_grad(merged_model, self.initial_model) + merged_model = LayerWiseMergedModel( + layer_wise_weight=layer_wise_weight, + pretrained_model=merged_model, + finetuned_models=list(finetuned_models.values()), + clamp_weights=False, + tie_weights=True, + strict=False, + ).cuda() + else: + raise ValueError("`init_weight` must be base or specified in the config file.") + + + self.projection, self.task_vectors = self.calculate_projection(pretrained_model, finetuned_models.values()) + # FW iteration + for step_idx in ( + pbar := tqdm( + range(self.max_iters if not self.is_debug_mode else 1), + ("[DEBUG MODE] " if self.is_debug_mode else "") + + "Frank-Wolfe Merging", + dynamic_ncols=True, + ) + ): + torch.cuda.empty_cache() + merged_model_tmp = merged_model.merge_and_copy().cuda() + with torch.no_grad(): + self.set_requires_grad(merged_model_tmp, self.initial_model) + torch.set_grad_enabled(True) + gradients = self.frank_wolfe_iteration(merged_model_tmp) + torch.set_grad_enabled(False) + grad_norm = torch.norm(torch.stack([torch.norm(g) for g in gradients.values()])) + + model_to_merge_names = [] if self.granularity == 'task' else {name: [] for name in merged_model.state_dict().keys()} + min_model, min_model_name, min_alignment, chosen_model = self.frank_wolfe_selection(gradients, finetuned_models, model_to_merge_names=model_to_merge_names, type=self.granularity) + + # Determine step size + step = 2 / (step_idx + 2) * self.step_size + + # print iteration information + log.info(f"Iteration {step_idx+1}, Task Vector: {min_model_name}, Gradient Norm: {grad_norm:.6f}, Inner Products: {min_alignment:.6f}, Chosen Model: {chosen_model}") + + # Calculate position of each min_model in the finetuned model lists in each layer + if self.granularity == 'task': + indices = [list(finetuned_models.keys()).index(min_model_name) for _ in range(len(merged_model.state_dict()))] + else: + indices = 
[list(finetuned_models.keys()).index(min_model_name[param_name]) for param_name in merged_model.state_dict().keys()] + merged_model = task_arithmetic_merge( + merged_model=merged_model, + finetuned_models=[min_model], + indices=indices, + scaling_factor=step * self.scaling_factor, + ) + + # Merge model with Adamerging + if self.ada_merge: + print("number of models to merge: ", len(modelpool.model_names)) + torch.set_grad_enabled(True) + merged_model = self.run_adamerging(merged_model) + torch.set_grad_enabled(False) + + with torch.no_grad(): + merged_model = merged_model.merge_and_unload() + self.set_requires_grad(merged_model, self.initial_model) + # eval and return model + merged_model = merged_model.cuda().eval() + return merged_model + + def set_requires_grad(self, merged_model, initial_model): + for name, param in initial_model.named_parameters(): + for n, p in merged_model.named_parameters(): + if name == n: + p.requires_grad = param.requires_grad diff --git a/fusion_bench/method/fw_merging/fw_hard_loss_approx.py b/fusion_bench/method/fw_merging/fw_hard_loss_approx.py new file mode 100644 index 00000000..b707e5a7 --- /dev/null +++ b/fusion_bench/method/fw_merging/fw_hard_loss_approx.py @@ -0,0 +1,476 @@ +""" +This script contains the general implementation of the Task Arithmetic method. + +http://arxiv.org/abs/2212.04089 +""" + +import logging +import os +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, Dict +from copy import deepcopy +from collections import defaultdict +from functools import partial +import functools + +from .utils import * + +import torch +from lightning.fabric.utilities.rank_zero import rank_zero_only +from omegaconf import DictConfig +from torch import Tensor, nn +from torch.utils.data import DataLoader +from tqdm.autonotebook import tqdm + +from fusion_bench.compat.method import ModelFusionAlgorithm +from fusion_bench.dataset.clip_dataset import CLIPDataset +from fusion_bench.mixins import CLIPClassificationMixin +from fusion_bench.compat.modelpool import ModelPool, HuggingFaceClipVisionPool +from fusion_bench.mixins.lightning_fabric import LightningFabricMixin +from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin +from fusion_bench.models.wrappers.layer_wise_fusion import ( + LayerWiseMergedModel, + get_layer_wise_weights, +) +from fusion_bench.utils.data import load_tensor_from_file +from fusion_bench.utils.type import TorchModelType + +if TYPE_CHECKING: + from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram + +from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin +from fusion_bench.modelpool import BaseModelPool +from fusion_bench.utils.data import InfiniteDataLoader +from fusion_bench.utils.state_dict_arithmetic import ( + state_dict_add, + state_dict_mul, + state_dict_sub, +) +from fusion_bench.utils.type import StateDictType +from fusion_bench.utils import instantiate + +log = logging.getLogger(__name__) + + +@torch.no_grad() +def task_arithmetic_merge( + pretrained_model: nn.Module, + finetuned_models: List[Dict[str, Tensor]], + scaling_factor: float, + inplace: bool = True, +) -> nn.Module: + """ + Merges the task vectors from multiple fine-tuned models into a single pre-trained model. + + Args: + pretrained_model (nn.Module): The pre-trained model to which the task vectors will be added. + finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated. 
+ scaling_factor (float): A factor by which the task vectors will be scaled before merging. + inplace (bool, optional): If True, the pre-trained model will be modified in place. + If False, a copy of the pre-trained model will be modified. Defaults to True. + + Returns: + nn.Module: The pre-trained model with the merged task vectors. + """ + if not inplace: + pretrained_model = deepcopy(pretrained_model) + if isinstance(finetuned_models[0], nn.Module): + finetuned_models = [deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models] + task_vector: StateDictType = None + # Calculate the total task vector + for model in finetuned_models: + print(model.keys()) + print("================") + print(pretrained_model.state_dict(keep_vars=True).keys()) + if task_vector is None: + task_vector = state_dict_sub( + model, + pretrained_model.state_dict(keep_vars=True), + ) + else: + task_vector = state_dict_add( + task_vector, + state_dict_sub( + model, + pretrained_model.state_dict(keep_vars=True), + ), + ) + # scale the task vector + task_vector = state_dict_mul(task_vector, scaling_factor) + # add the task vector to the pretrained model + state_dict = state_dict_add( + pretrained_model.state_dict(keep_vars=True), task_vector + ) + pretrained_model.load_state_dict(state_dict) + return pretrained_model + + +@torch.no_grad() +def ties_merge( + pretrained_model: nn.Module, + finetuned_models: List[Dict[str, Tensor]], + scaling_factor: float, + threshold: float, +) -> nn.Module: + remove_keys = [] + merge_func = "sum" + if isinstance(finetuned_models[0], nn.Module): + finetuned_models = [deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models] + + ptm_check = pretrained_model.state_dict(keep_vars=True) + + # Compute the task vectors + flat_ft = torch.vstack( + [state_dict_to_vector(check, remove_keys) for check in finetuned_models] + ) + flat_ptm = state_dict_to_vector(ptm_check, remove_keys) + tv_flat_checks = flat_ft - flat_ptm + + # Perform TIES Merging + merged_tv = ties_merging( + tv_flat_checks, + reset_thresh=threshold, + merge_func=merge_func, + ) + merged_check = flat_ptm + scaling_factor * merged_tv + merged_state_dict = vector_to_state_dict( + merged_check, ptm_check, remove_keys=remove_keys + ) + + # Load the merged state dict into the pretrained model + pretrained_model.load_state_dict(merged_state_dict) + return pretrained_model + +def entropy_loss(logits: Tensor, pred = None, eps: float = 1e-8) -> Tensor: + """ + Compute the entropy loss of a set of logits. + + Args: + logits (Tensor): The logits to compute the entropy loss of. + eps (float): A small value to avoid log(0). Default is 1e-8. + + Returns: + Tensor: The entropy loss of the logits. 
+ """ + # Ensure the logits tensor has 2 dimensions + assert ( + logits.dim() == 2 + ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}" + + # Compute the softmax probabilities + probs = torch.softmax(logits, dim=-1) + + # Compute the entropy loss + return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean() + + +class FrankWolfeHardLossApproxAlgorithm( + CLIPClassificationMixin, + ModelFusionAlgorithm, + SimpleProfilerMixin, +): + + + def __init__(self, + merge_fn: str, + step_size: float, + max_iters: int, + dataset_size:int, + tasks: List[str] = [], + granularity: str = 'task', + max_num_models: int = 100, + loss_fn: str = "cross_entropy", + init_weight: str = "", + scaling_factor: float = 1., + threshold: int = 20, + **kwargs): + """ + Initializes the TaskArithmeticAlgorithm with the given scaling factor. + + Args: + scaling_factor (int): The factor by which the task vectors will be scaled before merging. + """ + self.merger = merge_fn + if merge_fn == "task_arithmetic": + self.merge_fn = task_arithmetic_merge + elif merge_fn == "ties": + self.merge_fn = partial(ties_merge, threshold=threshold) + # elif merge_fn == "concrete_ta": + # self.merge_fn = ConcreteTaskArithmeticAlgorithmForCLIP( + # instantiate(OmegaConf.load("config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml")) + # ) + else: + raise ValueError(f"Unsupported merge_fn: {merge_fn}") + self.scaling_factor = scaling_factor + + self.init_weight = init_weight + self.step_size = step_size + self.max_iters = max_iters + self.granularity = granularity + self.loss_fn = loss_fn + self.tasks = tasks + self.dataset_size = dataset_size + self.max_num_models = max_num_models + super().__init__(**kwargs) + + + def on_frank_wolfe_iteration_start(self): + self.setup_zero_shot_classification_head() + + def calculate_projection(self, pretrained_model: nn.Module, finetuned_models: List[nn.Module]): + # Compute the svd and projection here + pretrained_sd = pretrained_model.state_dict(keep_vars=True) + filtered_keys = [ + k + for k in pretrained_sd.keys() + if ("encoder" in k and "layer_norm" not in k and "weight" in k) + ] + task_vectors = [] + for m in finetuned_models: + m.requires_grad_(False) + pretrained_model = pretrained_model.requires_grad_(False) + for model in finetuned_models: + model_sd = model.state_dict(keep_vars=True) + filtered_task_vector = { + k: (model_sd[k].cuda() - pretrained_sd[k].cuda()) for k in filtered_keys + } + task_vectors.append(filtered_task_vector) + + projection = {} + for layer_name in task_vectors[0].keys(): + for i, vector in enumerate(task_vectors): + layer_vector = vector[layer_name] + u, s, v = torch.linalg.svd(layer_vector, full_matrices=False) + if i == 0: + print(f"Computed SVD for {layer_name}...") + sum_u = torch.zeros_like(u, device=layer_vector.device) + sum_s = torch.zeros_like(s, device=layer_vector.device) + sum_v = torch.zeros_like(v, device=layer_vector.device) + + reduced_index_s = int(s.shape[0] / len(task_vectors)) + + # select only the first reduced_index_s columns of u and place them + sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[ + :, :reduced_index_s + ] + sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[ + :reduced_index_s + ] + # select only the first reduced_index_s rows of v and place them + sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[ + :reduced_index_s, : + ] + # SVD of shared subspace to avoid overlapping task vectors + u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False) + # 
u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False) + layer_proj = torch.matmul( + u_u[:, : int(s.shape[0] / len(task_vectors))], + u_u[:, : int(s.shape[0] / len(task_vectors))].T, + ) + projection[layer_name] = layer_proj # Projection matrix for each layer + + for m in finetuned_models: + m.requires_grad_(True) + pretrained_model = pretrained_model.requires_grad_(True) + return projection, task_vectors + + @functools.cache + def get_shuffled_loader_iter(self, task: str): + if self.loss_fn == "cross_entropy": + # get dataloader kwargs + dataloader_kwargs = self._dataloader_kwargs.copy() + dataloader_kwargs["shuffle"] = True + dataloader_kwargs["batch_size"] = 1 + + # get the test dataset + clip_dataset = CLIPDataset( + self.modelpool.load_train_dataset(task), self.clip_processor + ) + # create the dataloader + loader = DataLoader(clip_dataset, **dataloader_kwargs) + loader = self.fabric.setup_dataloaders(loader) + return iter(InfiniteDataLoader(loader)) + elif self.loss_fn == "entropy": + return super().get_shuffled_test_loader_iter( + task, + batch_size=1, + ) + else: + raise ValueError(f"Unsupported loss function: {self.loss_fn}") + + + def frank_wolfe_iteration(self, merged_model): + + merged_model.train() + # zero the gradients + for name, param in merged_model.named_parameters(): + param.requires_grad = True + param.grad = None + sd = merged_model.state_dict(keep_vars=True) + + losses = defaultdict(list) + gradients = {} + + for layer_name in self.task_vectors[0].keys(): + task_layer_vectors = torch.stack([vec[layer_name] for vec in self.task_vectors]) + merged_model_layer_vector = sd[layer_name].cuda() + initial_model_layer_vector = self.initial_model.state_dict(keep_vars=True)[layer_name].cuda() + losses[layer_name] = 0.0 + for task_layer_vector in task_layer_vectors: + # -layer_vector + part_1 = -task_layer_vector + # merged_model - layer_vector + part_2 = merged_model_layer_vector - initial_model_layer_vector - task_layer_vector + # dot product between part_1 and part_2 + inner_product = torch.sum(part_1 * part_2) + result = inner_product * inner_product + losses[layer_name] += result + + print(f"Layer: {layer_name}, DoGE Loss: {losses[layer_name].item()}") + # calculate the gradients + losses[layer_name].backward() + g = sd[layer_name].grad.clone().to("cpu") + g = g - self.projection[layer_name].to("cpu") @ g + gradients[layer_name] = g + sd[layer_name].grad = None + + + # calculate the loss + avg_loss = sum(losses.values()) / len(self.task_vectors) + log.info(f"Average Loss: {avg_loss}, Total Loss: {sum(losses.values())}") + del losses + + for name, param in merged_model.named_parameters(): + param.grad = None + merged_model.eval() + + return gradients + + def frank_wolfe_selection(self, gradients, checkpoints, model_to_merge_names={}, type='task'): + assert type in ['task', 'layer'], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']" + min_inner_product = float("inf") + min_model = None + min_model_name = None + log_dict = {} + if type == 'task': + for model_name, model_to_merge in checkpoints.items(): + model_to_merge = model_to_merge.to('cpu').state_dict() + inner_product_sum = 0 + for param_name, param_value in model_to_merge.items(): + # caclulate consine similarity + grad = gradients[param_name] if param_name in gradients else torch.zeros_like(param_value) + ckpt = model_to_merge[param_name] + param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (torch.norm(grad) * torch.norm(ckpt)) + inner_product_sum += param_alignment + 
log_dict[model_name] = inner_product_sum.item() + if inner_product_sum < min_inner_product and model_name not in model_to_merge_names: + min_inner_product = inner_product_sum + min_model = deepcopy(model_to_merge) + min_model_name = model_name + else: + min_model = {} + min_inner_product = {} + min_idx = {} + min_model_name = {} + for model_name, model_to_merge in checkpoints.items(): + model_to_merge = model_to_merge.to('cpu').state_dict() + for param_name, param_value in model_to_merge.items(): + # caclulate consine similarity + grad = gradients[param_name] if param_name in gradients else torch.zeros_like(param_value) + ckpt = model_to_merge[param_name] + param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (torch.norm(grad) * torch.norm(ckpt)) + if (param_name not in min_inner_product or param_alignment < min_inner_product[param_name]) and \ + model_name not in model_to_merge_names[param_name]: + min_inner_product[param_name] = param_alignment + # if min_inner_product[param_name] < 0: + min_model[param_name] = param_value + min_idx[param_name] = model_name + min_model_name[param_name] = model_name + # else: + # min_model[param_name] = torch.zeros_like(param_value) + min_inner_product = sum(min_inner_product.values()) + log_dict = {model_name: 0 for model_name in checkpoints.keys()} + for k in min_idx.values(): + log_dict[k] += 1 + + return min_model, min_model_name, min_inner_product, log_dict + + + + def run(self, modelpool: HuggingFaceClipVisionPool): + log.info("Fusing models using FW merging.") + self.modelpool = modelpool + self.log_hyperparams(self.config) + self.on_frank_wolfe_iteration_start() + + assert modelpool.has_pretrained, "Pretrained model is required." + finetuned_models = {name: modelpool.load_model(name) for name in modelpool.model_names[:self.max_num_models]} + pretrained_model = modelpool.load_model("_pretrained_") + + if self.init_weight: + if self.init_weight == 'base': + log.info("Initializing the merged model with the base model") + merged_model = pretrained_model + else: + log.info("Initializing the merged model with the initial weight") + if isinstance(self.init_weight, str): + # self.config.weights is a path to a saved tensor + layer_wise_weight = load_tensor_from_file(self.init_weight) + else: + raise ValueError(f"Unsupported weights format: {self.init_weight}") + + merged_model = LayerWiseMergedModel( + layer_wise_weight=layer_wise_weight, + pretrained_model=modelpool.load_model("_pretrained_"), + finetuned_models=list(finetuned_models.values()), + clamp_weights=False, + tie_weights=True, + strict=False, + ).cuda() + merged_model = merged_model.merge_and_unload() + else: + log.info("Initializing the merged model with merge function") + merged_model = self.merge_fn( + pretrained_model=modelpool.load_model("_pretrained_"), + finetuned_models=list(finetuned_models.values()), + scaling_factor=self.scaling_factor + ).cuda() + # merged_model = self.fabric.setup(merged_model) + + self.initial_model = modelpool.load_model("_pretrained_") + self.initial_model.load_state_dict(deepcopy(merged_model.state_dict())) + finetuned_models['initial'] = self.initial_model + # calculate projection + self.projection, self.task_vectors = self.calculate_projection(pretrained_model, finetuned_models.values()) + for step_idx in ( + pbar := tqdm( + range(self.max_iters if not self.is_debug_mode else 1), + ("[DEBUG MODE] " if self.is_debug_mode else "") + + "Frank-Wolfe Merging", + dynamic_ncols=True, + ) + ): + torch.cuda.empty_cache() + torch.set_grad_enabled(True) + 
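+ # `frank_wolfe_iteration` scores the current merge without task data: for each
+ # filtered layer it evaluates the surrogate loss
+ #     sum_t <tau_t, tau_t - (theta_merged - theta_init)>^2
+ # over the task vectors tau_t, backpropagates it, and subtracts from the
+ # gradient its component in the shared SVD subspace given by self.projection.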
gradients = self.frank_wolfe_iteration(merged_model.cuda()) + torch.set_grad_enabled(False) + grad_norm = torch.norm(torch.stack([torch.norm(g) for g in gradients.values()])) + + model_to_merge_names = [] if self.granularity == 'task' else {name: [] for name in merged_model.state_dict().keys()} + min_model, min_model_name, min_alignment, chosen_model = self.frank_wolfe_selection(gradients, finetuned_models, model_to_merge_names=model_to_merge_names, type=self.granularity) + + # Determine step size + step = 2 / (step_idx + 2) * self.step_size + + # print iteration information + log.info(f"Iteration {step_idx+1}, Task Vector: {min_model_name}, Gradient Norm: {grad_norm:.6f}, Inner Products: {min_alignment:.6f}, Chosen Model: {chosen_model}") + + merged_model = self.merge_fn( + pretrained_model=merged_model.to('cpu'), + finetuned_models=[min_model], + scaling_factor=step*self.scaling_factor, + ) + + torch.set_grad_enabled(False) + merged_model = merged_model.cuda().eval() + return merged_model diff --git a/fusion_bench/method/fw_merging/fw_soft_loss_approx.py b/fusion_bench/method/fw_merging/fw_soft_loss_approx.py new file mode 100644 index 00000000..0164d2f3 --- /dev/null +++ b/fusion_bench/method/fw_merging/fw_soft_loss_approx.py @@ -0,0 +1,574 @@ +""" +This script contains the general implementation of the Task Arithmetic method. + +http://arxiv.org/abs/2212.04089 +""" + +import logging +import os +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, Dict +from copy import deepcopy +from collections import defaultdict +from functools import partial +import numpy as np +import functools +import gc + +from .utils import * + +import torch +from lightning.fabric.utilities.rank_zero import rank_zero_only +from omegaconf import DictConfig +from torch import Tensor, nn +from torch.utils.data import DataLoader +from tqdm.autonotebook import tqdm + +from fusion_bench.compat.method import ModelFusionAlgorithm +from fusion_bench.dataset.clip_dataset import CLIPDataset +from fusion_bench.mixins import CLIPClassificationMixin +from fusion_bench.compat.modelpool import ModelPool, HuggingFaceClipVisionPool +from fusion_bench.mixins.lightning_fabric import LightningFabricMixin +from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin +from fusion_bench.models.wrappers.layer_wise_fusion import ( + LayerWiseMergedModel, + get_layer_wise_weights, +) +from fusion_bench.utils.data import load_tensor_from_file +from fusion_bench.utils.type import TorchModelType + +if TYPE_CHECKING: + from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram + +from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin +from fusion_bench.modelpool import BaseModelPool +from fusion_bench.utils.data import InfiniteDataLoader +from fusion_bench.utils.state_dict_arithmetic import ( + state_dict_add, + state_dict_mul, + state_dict_sub, +) +from fusion_bench.utils.type import StateDictType +from fusion_bench.utils import instantiate + +log = logging.getLogger(__name__) + + +def projection_simplex_sort(v, z=1): + # print(v.shape) + n_features = v.shape[0] # Get the number of elements in v + u, _ = torch.sort(v, descending=True) # Sort v in descending order + cssv = torch.cumsum(u, dim=0) - z # Compute cumulative sum and subtract z + ind = torch.arange(1, n_features + 1, dtype=torch.long, device=v.device) # Create index tensor (1 to n_features) + cond = u - cssv / ind > 0 # Condition to find rho + if cond.any(): # Ensure there is at least one 
valid rho + rho = ind[cond][-1] # Find the largest index satisfying the condition + theta = cssv[rho - 1] / rho # Compute the correct threshold theta + else: + theta = 0 # Default case when all values are zero or negative + w = torch.clamp(v - theta, min=0) # Compute the projected vector, ensuring non-negativity + return w + + +@torch.no_grad() +def task_arithmetic_merge( + pretrained_model: nn.Module, + finetuned_models: List[Dict[str, Tensor]], + scaling_factor: float, + inplace: bool = True, +) -> nn.Module: + """ + Merges the task vectors from multiple fine-tuned models into a single pre-trained model. + + Args: + pretrained_model (nn.Module): The pre-trained model to which the task vectors will be added. + finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated. + scaling_factor (float): A factor by which the task vectors will be scaled before merging. + inplace (bool, optional): If True, the pre-trained model will be modified in place. + If False, a copy of the pre-trained model will be modified. Defaults to True. + + Returns: + nn.Module: The pre-trained model with the merged task vectors. + """ + if not inplace: + pretrained_model = deepcopy(pretrained_model) + if isinstance(finetuned_models[0], nn.Module): + finetuned_models = [deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models] + task_vector: StateDictType = None + # Calculate the total task vector + for model in finetuned_models: + if task_vector is None: + task_vector = state_dict_sub( + model, + pretrained_model.state_dict(keep_vars=True), + ) + else: + task_vector = state_dict_add( + task_vector, + state_dict_sub( + model, + pretrained_model.state_dict(keep_vars=True), + ), + ) + # scale the task vector + task_vector = state_dict_mul(task_vector, scaling_factor) + # add the task vector to the pretrained model + state_dict = state_dict_add( + pretrained_model.state_dict(keep_vars=True), task_vector + ) + pretrained_model.load_state_dict(state_dict) + return pretrained_model + + +def entropy_loss(logits: Tensor, pred = None, eps: float = 1e-8) -> Tensor: + """ + Compute the entropy loss of a set of logits. + + Args: + logits (Tensor): The logits to compute the entropy loss of. + eps (float): A small value to avoid log(0). Default is 1e-8. + + Returns: + Tensor: The entropy loss of the logits. + """ + # Ensure the logits tensor has 2 dimensions + assert ( + logits.dim() == 2 + ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}" + + # Compute the softmax probabilities + probs = torch.softmax(logits, dim=-1) + + # Compute the entropy loss + return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean() + + +class FrankWolfeSoftLossApproxAlgorithm( + CLIPClassificationMixin, + ModelFusionAlgorithm, + SimpleProfilerMixin, +): + def __init__(self, + max_iters: int, + dataset_size:int, + ada_iters: int, + ada_coeff: float, + merge_fn: str, + granularity: str = "task", + max_num_models: int = 100, + step_size: float = 0.3, + tasks: List[str] = [], + init_weight: str = "", + ada_loss = "entropy_loss", + **kwargs): + """ + Initializes the TaskArithmeticAlgorithm with the given scaling factor. + + Args: + step_size (int): The factor by which the task vectors will be scaled before merging. 
+ """ + self.merge_fn = merge_fn + + self.init_weight = init_weight + self.max_iters = max_iters + self.ada_iters = ada_iters + self.ada_coeff = ada_coeff + self.granularity = granularity + self.tasks = tasks + self.step_size = step_size + self.dataset_size = dataset_size + self.max_num_models = max_num_models + self.ada_loss = ada_loss + super().__init__(**kwargs) + + + def on_frank_wolfe_iteration_start(self): + self.setup_zero_shot_classification_head() + + def calculate_projection(self, pretrained_model: nn.Module, finetuned_models: List[nn.Module]): + # Compute the svd and projection here + pretrained_sd = pretrained_model.state_dict(keep_vars=True) + filtered_keys = [ + k + for k in pretrained_sd.keys() + if ("encoder" in k and "layer_norm" not in k and "weight" in k) + ] + task_vectors = [] + for m in finetuned_models: + m.requires_grad_(False) + pretrained_model = pretrained_model.requires_grad_(False) + for model in finetuned_models: + model_sd = model.state_dict(keep_vars=True) + filtered_task_vector = { + k: (model_sd[k].to("cpu") - pretrained_sd[k].to("cpu")).detach() for k in filtered_keys + } + task_vectors.append(filtered_task_vector) + + projection = {} + for layer_name in task_vectors[0].keys(): + for i, vector in enumerate(task_vectors): + layer_vector = vector[layer_name] + u, s, v = torch.linalg.svd(layer_vector, full_matrices=False) + if i == 0: + print(f"Computed SVD for {layer_name}...") + sum_u = torch.zeros_like(u, device=layer_vector.device) + sum_s = torch.zeros_like(s, device=layer_vector.device) + sum_v = torch.zeros_like(v, device=layer_vector.device) + + reduced_index_s = int(s.shape[0] / len(task_vectors)) + + # select only the first reduced_index_s columns of u and place them + sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[ + :, :reduced_index_s + ] + sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[ + :reduced_index_s + ] + # select only the first reduced_index_s rows of v and place them + sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[ + :reduced_index_s, : + ] + # SVD of shared subspace to avoid overlapping task vectors + u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False) + # u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False) + layer_proj = torch.matmul( + u_u[:, : int(s.shape[0] / len(task_vectors))], + u_u[:, : int(s.shape[0] / len(task_vectors))].T, + ) + projection[layer_name] = layer_proj # Projection matrix for each layer + + for m in finetuned_models: + m.requires_grad_(True) + pretrained_model = pretrained_model.requires_grad_(True) + return projection, task_vectors + + @functools.cache + def get_shuffled_train_loader_iter(self, task: str, batch_size: int = 1): + # get dataloader kwargs + dataloader_kwargs = self._dataloader_kwargs.copy() + dataloader_kwargs["shuffle"] = True + dataloader_kwargs["batch_size"] = batch_size + + # get the test dataset + clip_dataset = CLIPDataset( + self.modelpool.load_train_dataset(task), self.clip_processor + ) + # create the dataloader + loader = DataLoader(clip_dataset, **dataloader_kwargs) + loader = self.fabric.setup_dataloaders(loader) + return iter(InfiniteDataLoader(loader)) + + + @functools.cache + def get_shuffled_test_loader_iter(self, task: str, batch_size: int = 1): + return super().get_shuffled_test_loader_iter( + task, + batch_size=batch_size + ) + + + def run_adamerging(self, module: LayerWiseMergedModel[TorchModelType]): + use_entropy_loss = self.ada_loss == 'entropy_loss' + + optimizer = torch.optim.Adam( + [module.merge_weight], lr=1e-3 
+ ) + module, optimizer = self.fabric.setup(module, optimizer) + module.train() + for step_idx in ( + pbar := tqdm( + range(self.ada_iters), + "AdaMerging (2/2)", + dynamic_ncols=True, + disable=not self.fabric.is_global_zero, + ) + ): + with self.profile("merge weights"): + module.merge_weights() + + metrics = {} + total_loss = None + tasks = self.modelpool.model_names if self.tasks == [] else self.tasks + if not use_entropy_loss: + loss_fn = nn.CrossEntropyLoss() + for task in tasks: + with self.profile("data loading"): + if use_entropy_loss: + batch = next(self.get_shuffled_test_loader_iter(task, batch_size=16)) + else: + batch = next(self.get_shuffled_train_loader_iter(task, batch_size=16)) + # NOTE: The labels are not allowed to be used during test-time adaptation + images = batch[0] + with self.profile("forward pass"): + logits = self.compute_logits(module, images, task) + if use_entropy_loss: + loss = entropy_loss(logits) + else: + loss = loss_fn(logits, batch[1]) + total_loss = loss if total_loss is None else total_loss + loss + optimizer.zero_grad() + with self.profile("compute grad"): + self.fabric.backward(total_loss) + + with self.profile("base optimizer step"): + optimizer.step() + + metrics.update({"train/loss": loss.item()}) + self.fabric.log_dict(metrics, step=step_idx) + pbar.set_postfix(metrics) + return module + + + def frank_wolfe_iteration(self, merged_model): + merged_model.train() + # zero the gradients + for name, param in merged_model.named_parameters(): + param.requires_grad = True + param.grad = None + sd = merged_model.state_dict(keep_vars=True) + + losses = defaultdict(list) + gradients = {} + + for layer_name in self.task_vectors[0].keys(): + task_layer_vectors = torch.stack([vec[layer_name] for vec in self.task_vectors]) + merged_model_layer_vector = sd[layer_name].to("cpu") + initial_model_layer_vector = self.initial_model.state_dict(keep_vars=True)[layer_name].to("cpu") + losses[layer_name] = 0.0 + for task_layer_vector in task_layer_vectors: + # -layer_vector + part_1 = -task_layer_vector + # merged_model - layer_vector + part_2 = merged_model_layer_vector - initial_model_layer_vector - task_layer_vector + # dot product between part_1 and part_2 + inner_product = torch.sum(part_1 * part_2) + result = inner_product * inner_product + losses[layer_name] += result + + # print(f"Layer: {layer_name}, DoGE Loss: {losses[layer_name].item()}") + # calculate the gradients + losses[layer_name].backward(retain_graph=False) + g = sd[layer_name].grad.clone().to("cpu") + g = (g - self.projection[layer_name] @ g) + gradients[layer_name] = g.to("cpu") + sd[layer_name].grad = None + del part_1, part_2, inner_product, result + torch.cuda.empty_cache() + + + # calculate the loss + avg_loss = sum(losses.values()) / len(self.task_vectors) + log.info(f"Average Loss: {avg_loss}, Total Loss: {sum(losses.values())}") + del losses + + for name, param in merged_model.named_parameters(): + param.grad = None + merged_model.eval() + + return gradients + + def frank_wolfe_selection(self, gradients, checkpoints, model_to_merge_names=[], type='task', num_models=4): + # min_models: list of min_model_dicts; min_model_names: list of model names; min_inner_products: list of inner products; log_dicts: dict of inner products per model + assert type in ['task', 'layer'], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']" + + inner_products = [] + models = [] + model_names = [] + log_dict = {} + if type == 'task': + for model_name, model_to_merge in 
checkpoints.items(): + model_to_merge = model_to_merge.to('cpu').state_dict() + inner_product_sum = 0 + for param_name, param_value in model_to_merge.items(): + # caclulate consine similarity + if param_name not in gradients: + continue + grad = gradients[param_name] + ckpt = model_to_merge[param_name] + param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (torch.norm(grad) * torch.norm(ckpt)) + inner_product_sum += param_alignment + + inner_products.append(inner_product_sum) + models.append(model_to_merge) + model_names.append(model_name) + log_dict[model_name] = inner_product_sum.item() + + # if inner_product_sum < min_inner_product and model_name not in model_to_merge_names: + # min_inner_product = inner_product_sum + # min_model = deepcopy(model_to_merge) + # min_model_name = model_name + # get smallest k model indices + min_inner_products = [] + min_models = [] + min_model_names = [] + arr = np.array(inner_products) + num_models = 4 + indices = np.argpartition(arr, num_models)[:num_models] + for i in indices: + min_inner_products.append(arr[i]) + min_models.append(deepcopy(models[i])) + min_model_names.append(model_names[i]) + print("models: ", min_model_names) + else: + param_candidates = defaultdict(list) # param_name -> list of (cos_sim, model_name, tensor) + + # Collect all cosine similarities for each layer + for model_name, model_to_merge in checkpoints.items(): + model_to_merge = model_to_merge.to('cpu').state_dict() + for param_name, param_value in model_to_merge.items(): + if param_name not in gradients: + grad = torch.zeros_like(param_value) + else: + grad = gradients[param_name] + ckpt = param_value + denom = torch.norm(grad) * torch.norm(ckpt) + if denom == 0: + param_alignment = torch.tensor(float(0)) + else: + param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / denom + param_candidates[param_name].append((param_alignment.item(), model_name, param_value)) + + # Select top-k for each layer + min_models = {} # list of dicts: one dict per selected model (for each param) + min_model_names = defaultdict(list) # list of lists: one list per param, containing top-k model names + min_inner_products = defaultdict(list) # list of lists: one list per param, containing top-k similarities + log_dict = {model_name: 0 for model_name in checkpoints.keys()} + + for param_name, candidates in param_candidates.items(): + # Sort by lowest cosine similarity + sorted_candidates = sorted(candidates, key=lambda x: x[0]) + top_k = sorted_candidates[:num_models] + + for cos_sim, model_name, param_tensor in top_k: + min_models.setdefault(param_name, []).append(param_tensor) + min_model_names[param_name].append(model_name) + min_inner_products[param_name].append(cos_sim) + log_dict[model_name] += 1 + + return min_models, min_model_names, min_inner_products, log_dict + + + + def run(self, modelpool: HuggingFaceClipVisionPool): + log.info("Fusing models using FW merging.") + self.modelpool = modelpool + tasks = self.tasks if self.tasks else self.modelpool.model_names + self.log_hyperparams(self.config) + self.on_frank_wolfe_iteration_start() + + assert modelpool.has_pretrained, "Pretrained model is required." 
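+ # granularity == 'task': the selection step below returns the k whole checkpoints
+ # whose summed cosine alignment with the projected gradient is lowest.
+ # granularity == 'layer': the k least-aligned checkpoints are chosen per
+ # parameter tensor independently.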
+ finetuned_models = {name: modelpool.load_model(name) for name in modelpool.model_names[:self.max_num_models]} + pretrained_model = modelpool.load_model("_pretrained_") + + if self.init_weight == 'base' or self.init_weight == '': + merged_model = modelpool.load_model("_pretrained_") + else: + log.info("Initializing the merged model with the initial weight") + if isinstance(self.init_weight, str): + # self.config.weights is a path to a saved tensor + layer_wise_weight = load_tensor_from_file(self.init_weight) + else: + raise ValueError(f"Unsupported weights format: {self.init_weight}") + + layerwise_merged_model = LayerWiseMergedModel( + layer_wise_weight=layer_wise_weight, + pretrained_model=pretrained_model, + finetuned_models=list(finetuned_models.values())[:self.max_num_models], + clamp_weights=False, + tie_weights=True, + strict=False, + ).cuda() + merged_model = layerwise_merged_model.merge_and_unload() + + self.initial_model = modelpool.load_model("_pretrained_") + self.set_requires_grad(merged_model, self.initial_model) + # initial_model.load_state_dict(deepcopy(merged_model.state_dict())) + # finetuned_models['initial'] = initial_model + + # calculate projection + task_models = finetuned_models.values() + self.projection, self.task_vectors = self.calculate_projection(pretrained_model, task_models) + for step_idx in ( + pbar := tqdm( + range(self.max_iters if not self.is_debug_mode else 1), + ("[DEBUG MODE] " if self.is_debug_mode else "") + + "Frank-Wolfe Merging", + dynamic_ncols=True, + ) + ): + torch.cuda.empty_cache() + # Find the task vector with the most alignment to the gradient + models_dict_to_merge = [] + model_to_merge_names = [] if self.granularity == 'task' else {name: [] for name in merged_model.state_dict().keys()} + inner_products = [] + + torch.set_grad_enabled(True) + torch.cuda.empty_cache() + # calculate gradient once, loss is global + gradients = self.frank_wolfe_iteration(merged_model.cuda()) + torch.set_grad_enabled(False) + grad_norm = torch.norm(torch.stack([torch.norm(g) for g in gradients.values()])) + + # select number of tasks of models + # min_models: list of min_model_dicts; min_model_names: list of model names; min_inner_products: list of inner products; log_dict: dict of inner products per model + min_models, min_model_names, min_inner_products, log_dict = self.frank_wolfe_selection(gradients, finetuned_models, model_to_merge_names, type=self.granularity, num_models=len(tasks)) + if self.granularity == 'task': + model_to_merge_names = min_model_names + else: + for model_i in min_model_names: + for param_name, model_name in zip(gradients.keys(), model_i): + model_to_merge_names[param_name].append(model_name) + models_dict_to_merge = min_models + inner_products = min_inner_products + + for task in tasks: + log.info(f"Task: {task}, Inner Products: {log_dict[task]}") + + + # print iteration information + log.info(f"Iteration {step_idx+1}, Task Vector: {model_to_merge_names}, Gradient Norm: {grad_norm:.6f}, Inner Products: {inner_products}") + + if self.merge_fn == 'adamerging': + models_to_merge = [modelpool.load_model('_pretrained_').to("cpu") for _ in range(len(models_dict_to_merge))] + layer_wise_weight = get_layer_wise_weights( + num_models=len(models_to_merge), + num_layers=len( + tuple( + filter(lambda p: p.requires_grad, models_to_merge[0].parameters()) + ) + ), + init_values=self.ada_coeff if step_idx > 0 else 0.3, + ) + for model_to_merge, model_to_merge_dict in zip(models_to_merge, models_dict_to_merge): + 
model_to_merge.load_state_dict(model_to_merge_dict) + layerwise_merged_model = LayerWiseMergedModel( + layer_wise_weight=layer_wise_weight.to("cpu"), + pretrained_model=merged_model.to("cpu"), + finetuned_models=models_to_merge, + clamp_weights=False, + tie_weights=True, + strict=False, + ).cuda() + torch.cuda.empty_cache() + torch.set_grad_enabled(True) + layerwise_merged_model = self.run_adamerging(layerwise_merged_model) + torch.set_grad_enabled(False) + with torch.no_grad(): + merged_model = layerwise_merged_model.merge_and_unload() + self.set_requires_grad(merged_model, self.initial_model) + del models_to_merge, layerwise_merged_model, layer_wise_weight, models_dict_to_merge + torch.cuda.empty_cache() + else: + step = 2 / (step_idx + 2) * self.step_size if step_idx > 0 else 1 + merged_model = task_arithmetic_merge(merged_model.to('cpu'), models_dict_to_merge, 0.3*step) + del models_dict_to_merge + + torch.set_grad_enabled(False) + merged_model = merged_model.cuda().eval() + return merged_model + + def set_requires_grad(self, merged_model, initial_model): + for name, param in initial_model.named_parameters(): + for n, p in merged_model.named_parameters(): + if name == n: + p.requires_grad = param.requires_grad diff --git a/fusion_bench/models/wrappers/layer_wise_fusion_fw_am.py b/fusion_bench/models/wrappers/layer_wise_fusion_fw_am.py new file mode 100644 index 00000000..5bf50d43 --- /dev/null +++ b/fusion_bench/models/wrappers/layer_wise_fusion_fw_am.py @@ -0,0 +1,302 @@ +import functools +import logging +from copy import deepcopy +from typing import ( # noqa: F401 + Any, + Callable, + Dict, + Generic, + Iterator, + List, + Optional, + TypeVar, +) + +import torch +from torch import Tensor, nn +from torch.func import functional_call + +from fusion_bench.models.utils import del_attr, get_attr, set_attr +from fusion_bench.utils.type import StateDictType, TorchModelType + +__all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"] + +log = logging.getLogger(__name__) + + +def get_layer_wise_weights( + num_models: int, + num_layers: int, + init_values: float = None, + dtype: torch.dtype = torch.float32, +): + """ + Return a tensor of layer-wise weights for the given number of models and layers. + + Args: + num_models (int): The number of models to fuse. + num_layers (int): The number of layers in each model. + init_values (float, optional): The initial value for each weight. Defaults to 1.0 / num_models. + dtype (torch.dtype): dtype of weights. This should be the same with model dtype. + + Returns: + Tensor: A tensor of shape (num_models, num_layers) containing the layer-wise weights. + """ + assert num_models >= 1, f"num_models must be >= 1, got {num_models}" + assert num_layers >= 1, f"num_layers must be >= 1, got {num_layers}" + if init_values is None: + init_values = 1.0 / num_models + return torch.full((num_models, num_layers), init_values, dtype=dtype) + + +def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]): + """ + Fuse the layer-wise weights with the given state dictionaries. + + Args: + layer_wise_weight (Tensor): A tensor of shape (num_models,) containing the layer-wise weights. + state_dicts (List[Tensor]): A list of state dictionaries, each containing the weights for a single layer. + + Returns: + Tensor: A tensor of shape (num_params,) containing the fused weights. 
+ """ + assert len(layer_wise_weight) == len( + tensors + ), f"layer_wise_weight.shape={layer_wise_weight.shape}, len(tensors)={len(tensors)}" + return sum( + layer_wise_weight[i] * w.to(layer_wise_weight.device) + for i, w in enumerate(tensors) + ) + + +def fuse_weights( + layer_wise_weight: Tensor, state_dicts: List[StateDictType] +) -> StateDictType: + """ + Fuse the weights of multiple models using layer-wise fusion. + + Args: + layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) representing the weight of each layer for each model. + state_dicts (List[StateDict]): A list of state dictionaries, one for each model. + + Returns: + A dictionary mapping each weight tensor key to the fused weight tensor. + """ + num_models = len(state_dicts) + num_layers = len(state_dicts[0]) + assert layer_wise_weight.shape == ( + num_models, + num_layers, + ), f"layer_wise_weight.shape={layer_wise_weight.shape}, expected (num_models, num_layers): ({num_models}, {num_layers})" + return { + k: _fuse_weights( + layer_wise_weight[:, i], [state_dict[k] for state_dict in state_dicts] + ) + for i, k in enumerate(state_dicts[0].keys()) + } + + +class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]): + _merged_state_dict: StateDictType = None + + def __init__( + self, + layer_wise_weight: Tensor, + pretrained_model: TorchModelType, + finetuned_models: List[TorchModelType], + clamp_weights: bool = True, + tie_weights: bool = False, + strict: bool = True, + sparsity_ratio: Optional[float] = None, + normalized_merging_weights: bool = False, + ): + R""" + This class wraps a pretrained model and a list of finetuned models, and merges the weights of the finetuned models into the pretrained model using layer-wise fusion. + + Reference: + + (ICLR 2024) Yang E, Wang Z, Shen L, et al. Adamerging: Adaptive model merging for multi-task learning. https://arxiv.org/pdf/2310.02575 + + Args: + layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) representing the weight of each layer for each model. + pretrained_model (nn.Module): The pretrained model to merge the weights into. + finetuned_models (List[nn.Module]): A list of finetuned models to merge the weights from. This should have the same architecture as the pretrained model. We use these models to compute the task vectors. + clamp_weights (bool, optional): If True, the layer-wise weights will be clamped to [0, 1]. Defaults to True. + tie_weights (bool, optional): This option passes the `tie_weights` argument to the `functional_call` function. Defaults to False. + strict (bool, optional): This option passes the `strict` argument to the `functional_call` function. Defaults to True. + sparsity_ratio (float, optional): If `sparsity_ratio` is provided, the task vector will be pruned before merging. A high spasity level can save the memory usage during merging. + normalized_merging_weights (bool, optional): If True, the layer-wise weights will be normalized for each layer, so that the sum of weights across models for each layer is 1. Defaults to False. 
+ """ + super().__init__() + self.clamp_weights = clamp_weights + self.tie_weights = tie_weights + self.strict = strict + self.sparsity_ratio = sparsity_ratio + self.nromalized_merging_weights = normalized_merging_weights + + self.merge_weight = nn.Parameter(layer_wise_weight, requires_grad=True) + + for name, param in pretrained_model.named_parameters(): + if not param.requires_grad: + for m in finetuned_models: + del_attr(m, name.split(".")) + else: + for m in finetuned_models: + get_attr(m, name.split(".")).data = ( + get_attr(m, name.split(".")) - param + ) + + self.pretrained_model = pretrained_model.requires_grad_(False) + for m in finetuned_models: + m.requires_grad_(False) + + self.task_vectors = nn.ModuleList(finetuned_models) + + # if `sparisty_ratio` is given, pruning the task vectors. + if sparsity_ratio is not None: + from fusion_bench.method.pruning.prune_utils import ( + unstructured_magnitude_prune_, + ) + + for name, param in self.task_vectors.named_parameters(): + if param.dim() != 2: + continue + print(f"pruning {name}") + pruned_param = unstructured_magnitude_prune_( + param.data.clone(), torch.abs, sparsity_ratio=sparsity_ratio + ) + set_attr( + self.task_vectors, + name.split("."), + nn.Parameter(pruned_param.to_sparse(), requires_grad=False), + ) + + @property + def forward_model(self): + return functools.partial( + functional_call, + self.pretrained_model, + self._merged_state_dict, + tie_weights=self.tie_weights, + strict=self.strict, + ) + + def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None): + self.merge_weights(task_vector_mask=task_vector_mask) + self.pretrained_model.load_state_dict(self._merged_state_dict) + return self.pretrained_model + + def merge_and_copy(self, task_vector_mask: Optional[Dict[str, Tensor]] = None) -> nn.Module: + self.merge_weights(task_vector_mask=task_vector_mask) + # do not unload the model, so that we can use the merged weights for further training + model = deepcopy(self.pretrained_model) + model.load_state_dict(self._merged_state_dict) + return model + + def state_dict(self): + self.merge_weights() + return self._merged_state_dict + + def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None): + """ + Merges the weights of the model. + Call this after each update step. + """ + if self.clamp_weights: + layer_wise_weight = self.merge_weight.clamp(0, 1) + else: + layer_wise_weight = self.merge_weight + if self.nromalized_merging_weights: + # normalize the weights for each layer, so that the sum of weights across models for each layer is 1. 
+ layer_wise_weight = layer_wise_weight.softmax(dim=0) + + state_dict = self.pretrained_model.state_dict(keep_vars=True) + # shape of layer_wise_weight: (num_models, num_layers) + for weight, task_vector in zip(layer_wise_weight, self.task_vectors): + assert len(list(task_vector.named_parameters())) == weight.size(0) + if task_vector_mask is not None: + weight = [ + w * task_vector_mask[name] + for w, (name, param) in zip(weight, task_vector.named_parameters()) + ] + for w, (name, param) in zip(weight, task_vector.named_parameters()): + state_dict[name] = state_dict[name] + param * w + self._merged_state_dict = state_dict + + return state_dict + + def forward(self, *args, **kwargs): + if self._merged_state_dict is None: + self.merge_weights() + return self.forward_model(args=args, kwargs=kwargs) + + # def __getattr__(self, name: str) -> Any: + # try: + # return super().__getattr__(name) + # except AttributeError: + # attr = getattr(self.model, name) + # if isinstance(attr, Callable): + # warnings.warn( + # f"forwarding `{name}` to the underlying model", UserWarning + # ) + # return attr + + # def __setattr__(self, name: str, value: Any) -> None: + # try: + # super().__setattr__(name, value) + # except AttributeError: + # setattr(self.model, name, value) + + +def merge_weights(module: nn.Module): + """ + Merges the weights for all `LayerWiseMergedModel` instances within the given module. + + Args: + module (nn.Module): The module to process. + """ + if isinstance(module, LayerWiseMergedModel): + module.merge_weights() + return + else: + for submodule in module.children(): + merge_weights(submodule) + + +def merge_and_unload(module: nn.Module): + """ + Merges and unloads all `LayerWiseMergedModel` instances within the given module. + + Args: + module (nn.Module): The module to process. + + Returns: + nn.Module: The updated module with merged weights. + """ + if isinstance(module, LayerWiseMergedModel): + return module.merge_and_unload() + else: + for name, submodule in module.named_children(): + need_merge = isinstance(submodule, LayerWiseMergedModel) + submodule = merge_and_unload(submodule) + if need_merge: + setattr(module, name, submodule) + return module + + +def fix_other_parts(module: nn.Module): + """ + Sets all parameters in the module to not require gradients, except for the merge weights + in `LayerWiseMergedModel` instances. + + Args: + module (nn.Module): The module to process. + + Returns: + nn.Module: The module with updated parameter requirements. + """ + module.requires_grad_(False) + for submodule in module.modules(): + if isinstance(submodule, LayerWiseMergedModel): + submodule.merge_weight.requires_grad_(True) + return module
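+
+ if __name__ == "__main__":
+     # Minimal usage sketch of the wrapper above. The two-layer MLPs are
+     # illustrative stand-ins; the merging pipeline passes CLIP vision encoders
+     # (a pretrained model plus fine-tuned checkpoints with identical architectures).
+     def make_model() -> nn.Module:
+         return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
+
+     pretrained = make_model()
+     finetuned = [make_model() for _ in range(3)]
+
+     # one merging weight per model and per trainable parameter tensor
+     num_layers = len([p for p in pretrained.parameters() if p.requires_grad])
+     weights = get_layer_wise_weights(num_models=3, num_layers=num_layers)
+
+     wrapper = LayerWiseMergedModel(
+         layer_wise_weight=weights,
+         pretrained_model=pretrained,
+         finetuned_models=finetuned,
+         clamp_weights=False,
+         tie_weights=True,
+         strict=False,
+     )
+     merged = wrapper.merge_and_unload()
+     _ = merged(torch.randn(2, 4))  # behaves like an ordinary model after merging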