Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions config/method/fw_merging/fw_hard_am.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Method config for fusion_bench.method.FrankWolfeHardAdamergingAlgorithm.
# NOTE(review): `_target_` looks like a Hydra-style instantiation key, so the
# remaining keys are presumably forwarded to that class as keyword
# arguments — confirm against the class signature.
_target_: fusion_bench.method.FrankWolfeHardAdamergingAlgorithm
merge_fn: task_arithmetic
max_iters: 10
step_size: 0.1
dataset_size: 100
tasks: []  # explicit empty list (fw_soft_loss_approx.yaml leaves this key empty, i.e. null)
init_weight: base  # the non-AdaMerging hard config (fw_hard_loss_approx.yaml) leaves this empty
loss_fn: cross_entropy
scaling_factor: 0.3
max_num_models: 100
granularity: task
init_layer_weights: 0.0
# The `ada_*` keys are presumably the AdaMerging-stage settings — TODO confirm.
ada_merge: True
ada_max_steps: 1000
ada_optimizer: adam
# NOTE(review): `1e-3` has no decimal point; a plain YAML 1.1 loader such as
# PyYAML reads it as the string "1e-3". Confirm the config loader resolves it
# to a float (or write `1.0e-3`).
ada_lr: 1e-3
16 changes: 16 additions & 0 deletions config/method/fw_merging/fw_hard_am_loss_approx.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Method config for
# fusion_bench.method.FrankWolfeHardAdamergingLossApproxAlgorithm.
# Apart from `_target_`, every value matches fw_hard_am.yaml.
# NOTE(review): `_target_` looks like a Hydra-style instantiation key, so the
# remaining keys are presumably forwarded to that class as keyword
# arguments — confirm against the class signature.
_target_: fusion_bench.method.FrankWolfeHardAdamergingLossApproxAlgorithm
merge_fn: task_arithmetic
max_iters: 10
step_size: 0.1
dataset_size: 100
tasks: []  # explicit empty list (fw_soft_loss_approx.yaml leaves this key empty, i.e. null)
init_weight: base  # the non-AdaMerging hard config (fw_hard_loss_approx.yaml) leaves this empty
loss_fn: cross_entropy
scaling_factor: 0.3
max_num_models: 100
granularity: task
init_layer_weights: 0.0
# The `ada_*` keys are presumably the AdaMerging-stage settings — TODO confirm.
ada_merge: True
ada_max_steps: 1000
ada_optimizer: adam
# NOTE(review): `1e-3` has no decimal point; a plain YAML 1.1 loader such as
# PyYAML reads it as the string "1e-3". Confirm the config loader resolves it
# to a float (or write `1.0e-3`).
ada_lr: 1e-3
11 changes: 11 additions & 0 deletions config/method/fw_merging/fw_hard_loss_approx.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Method config for fusion_bench.method.FrankWolfeHardLossApproxAlgorithm.
# Carries the same keys as fw_hard_am.yaml minus `init_layer_weights` and the
# `ada_*` settings.
# NOTE(review): `_target_` looks like a Hydra-style instantiation key, so the
# remaining keys are presumably forwarded to that class as keyword
# arguments — confirm against the class signature.
_target_: fusion_bench.method.FrankWolfeHardLossApproxAlgorithm
merge_fn: task_arithmetic
max_iters: 10
step_size: 0.1
dataset_size: 100
tasks: []
# NOTE(review): an empty value parses as null. The fw_hard_am*.yaml configs
# set `init_weight: base` — confirm that null is intended here.
init_weight:
loss_fn: cross_entropy
scaling_factor: 0.3
max_num_models: 100
granularity: task
12 changes: 12 additions & 0 deletions config/method/fw_merging/fw_soft_loss_approx.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Method config for fusion_bench.method.FrankWolfeSoftLossApproxAlgorithm.
# NOTE(review): `_target_` looks like a Hydra-style instantiation key, so the
# remaining keys are presumably forwarded to that class as keyword
# arguments — confirm against the class signature.
_target_: fusion_bench.method.FrankWolfeSoftLossApproxAlgorithm
init_weight:  # empty value parses as null
max_iters: 10
merge_fn: 'adamerging'
# NOTE(review): empty value parses as null, whereas the hard-variant configs
# use an explicit empty list (`tasks: []`) — confirm the algorithm accepts
# null here.
tasks:
ada_iters: 500
dataset_size: 100
# NOTE(review): `1e-8` has no decimal point; a plain YAML 1.1 loader such as
# PyYAML reads it as the string "1e-8". Confirm the config loader resolves it
# to a float (or write `1.0e-8`).
ada_coeff: 1e-8
step_size: 0.1
max_num_models: 100
granularity: task
ada_loss: entropy_loss
4 changes: 2 additions & 2 deletions fusion_bench/method/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
"SparseLoForLlama",
"PCPSparseLoForLlama",
],
"fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
"fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm", "FrankWolfeHardAdamergingAlgorithm", "FrankWolfeSoftLossApproxAlgorithm", "FrankWolfeHardLossApproxAlgorithm", "FrankWolfeHardAdamergingLossApproxAlgorithm"],
}


Expand Down Expand Up @@ -182,7 +182,7 @@
from .ties_merging import TiesMergingAlgorithm
from .we_moe import CLIPWeightEnsemblingMoEAlgorithm
from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm, FrankWolfeHardAdamergingAlgorithm, FrankWolfeSoftLossApproxAlgorithm, FrankWolfeHardLossApproxAlgorithm, FrankWolfeHardAdamergingLossApproxAlgorithm

else:
sys.modules[__name__] = LazyImporter(
Expand Down
6 changes: 5 additions & 1 deletion fusion_bench/method/fw_merging/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
from .fw_hard import FrankWolfeHardAlgorithm
from .fw_soft import FrankWolfeSoftAlgorithm
from .fw_soft import FrankWolfeSoftAlgorithm
from .fw_hard_am import FrankWolfeHardAdamergingAlgorithm
from .fw_soft_loss_approx import FrankWolfeSoftLossApproxAlgorithm
from .fw_hard_loss_approx import FrankWolfeHardLossApproxAlgorithm
from .fw_hard_am_loss_approx import FrankWolfeHardAdamergingLossApproxAlgorithm
Loading