Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9e55657
Initial commit
TristenStegall Apr 7, 2024
6a0f8a6
Merge remote-tracking branch 'upstream/main' into Sample-Sparsification
TristenStegall Apr 7, 2024
5709da6
Initial implementation
TristenStegall Apr 7, 2024
2e87e87
Semi-complete implementation
TristenStegall Apr 7, 2024
9f06353
Fix negative power
TristenStegall Apr 7, 2024
bb1c74f
better handling of 0 powers (?)
TristenStegall Apr 7, 2024
848d72d
Revert "better handling of 0 powers (?)"
TristenStegall Apr 7, 2024
1a35512
Fix handling of Sparse tensors
TristenStegall Apr 7, 2024
638d785
Merge branch 'main' into Sample-Sparsification
MonsterAzi Apr 7, 2024
df86fdc
Add rescaling
TristenStegall Apr 7, 2024
b42ed87
Clean up code
TristenStegall Apr 7, 2024
56f12d0
Fixed edge cases
TristenStegall Apr 7, 2024
e0e97d9
reformatted
TristenStegall Apr 8, 2024
b71e908
reformat
TristenStegall Apr 8, 2024
d99bd82
reformat
TristenStegall Apr 8, 2024
2cf7243
Merge branch 'Sample-Sparsification' of https://github.com/MonsterAzi…
TristenStegall Apr 8, 2024
192af12
Basic Ranked implementation
TristenStegall Apr 9, 2024
b066209
Merge branch 'main' into Sample-Sparsification
MonsterAzi Apr 10, 2024
2c214ef
Working implementation
TristenStegall Apr 10, 2024
480e9fd
Working implementation
TristenStegall Apr 10, 2024
73bd19b
Merge branch 'Sample-Sparsification' of https://github.com/MonsterAzi…
TristenStegall Apr 10, 2024
3f6207e
Added Smoothing
TristenStegall Apr 13, 2024
bd6469b
Undo some local stuff
TristenStegall Apr 13, 2024
dc381aa
Reformatted
TristenStegall Apr 13, 2024
e90e2f6
again :P
TristenStegall Apr 13, 2024
15b699d
again again :P
TristenStegall Apr 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mergekit/merge_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,30 +39,50 @@ def get(method: str) -> MergeMethod:
sparsification_method=None,
default_normalize=False,
default_rescale=False,
default_smooth=False,
)
elif method == "ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.magnitude,
default_normalize=True,
default_rescale=False,
default_smooth=False,
)
elif method == "dare_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.random,
default_normalize=False,
default_rescale=True,
default_smooth=False,
)
elif method == "dare_linear":
return GeneralizedTaskArithmeticMerge(
consensus_method=None,
sparsification_method=SparsificationMethod.random,
default_normalize=False,
default_rescale=True,
default_smooth=False,
)
elif method == "model_stock":
return ModelStockMerge()
elif method == "sample_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.sample,
default_normalize=False,
default_rescale=True,
default_smooth=False,
)
elif method == "ranked_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.ranked,
default_normalize=False,
default_rescale=True,
default_smooth=False,
)
raise RuntimeError(f"Unimplemented merge method {method}")


Expand Down
7 changes: 7 additions & 0 deletions mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True):
sparsification_method: Optional[SparsificationMethod]
default_normalize: bool
default_rescale: bool
default_smooth: bool

def parameters(self) -> List[ConfigParameterDef]:
return [
Expand All @@ -49,6 +50,9 @@ def parameters(self) -> List[ConfigParameterDef]:
ConfigParameterDef(
name="rescale", required=False, default_value=self.default_rescale
),
ConfigParameterDef(
name="smooth", required=False, default_value=self.default_smooth
),
]

def tensor_parameters(self) -> List[ConfigParameterDef]:
Expand All @@ -73,6 +77,7 @@ def make_task(
int8_mask=parameters["int8_mask"],
normalize=parameters["normalize"],
rescale=parameters["rescale"],
smooth=parameters["smooth"],
out_tensor_name=output_weight.name,
)

Expand All @@ -86,6 +91,7 @@ class GTATask(Task[torch.Tensor]):
int8_mask: bool
normalize: bool
rescale: bool
smooth: bool

def uses_accelerator(self) -> bool:
return True
Expand Down Expand Up @@ -116,6 +122,7 @@ def execute(
density=tv_info["density"],
method=self.method.sparsification_method,
rescale=self.rescale,
smooth=self.smooth,
)

deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
Expand Down
83 changes: 82 additions & 1 deletion mergekit/sparsify.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
class SparsificationMethod(str, Enum):
    """Names of the sparsification strategies that `sparsify` dispatches on.

    Members also derive from `str`, so each one compares equal to its plain
    string value.
    """

    # Routed to `magnitude(...)` by `sparsify`.
    magnitude = "magnitude"
    # Routed to `bernoulli(...)` by `sparsify`.
    random = "random"
    # Routed to `sample(...)` by `sparsify`.
    sample = "sample"
    # Routed to `ranked(...)` by `sparsify`.
    ranked = "ranked"


def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor):
Expand Down Expand Up @@ -78,15 +80,94 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens
return res.to(tensor.dtype)


def ranked(
    tensor: torch.Tensor, density: float, rescale: bool, smooth: bool
) -> torch.Tensor:
    """Sparsify `tensor` with a mask derived from the magnitude rank of each element.

    Elements are ranked by absolute value. The largest receives a mask value
    of 1, the smallest a mask value of 0, and the ones in between follow
    `linspace(1, 0) ** (1 / density - 1)`. Since the integral of `x ** k` over
    [0, 1] is `1 / (k + 1)`, the mean of that ramp is approximately `density`
    for large tensors.

    Args:
        tensor: Values to sparsify.
        density: Target fraction of the tensor to retain. Values >= 1 return
            `tensor` unchanged. Must be non-zero (`1 / density` is evaluated).
        rescale: If true, the result is produced by `rescale_sum(tensor, mask)`
            (defined elsewhere in this module); otherwise it is `tensor * mask`.
        smooth: If true, the continuous mask is applied directly; otherwise
            each mask value is used as the probability of a Bernoulli draw,
            giving a 0/1 mask.

    Returns:
        The masked tensor in the dtype of `tensor`, or `tensor` itself when no
        sparsification is needed.
    """
    total = tensor.numel()
    if density >= 1 or total == 0:
        return tensor

    # Already at least as sparse as requested (in line with trimming).
    # NOTE: `x ** 0.0` is 1 even for x == 0, so it cannot be used to count
    # non-zero entries; count them explicitly instead.
    if torch.count_nonzero(tensor).item() <= density * total:
        return tensor

    # torch.bernoulli is not implemented for float16 on CPU; build the mask
    # in float32 there and cast the result back at the end.
    if tensor.device.type == "cpu" and tensor.dtype == torch.float16:
        work_dtype = torch.float32
    else:
        work_dtype = tensor.dtype

    # reshape (not view) so non-contiguous inputs are accepted too.
    magnitudes = tensor.abs().reshape(-1)
    if magnitudes.device.type == "cpu":
        magnitudes = magnitudes.float()
    order = torch.argsort(magnitudes, descending=True)

    # Use the full device (including its index), not just the device type,
    # so the ramp lands on the same accelerator as the tensor.
    ramp = torch.linspace(
        1, 0, steps=total, device=tensor.device, dtype=work_dtype
    ).pow((1 / density) - 1)
    flat_mask = torch.empty(total, device=tensor.device, dtype=work_dtype)
    flat_mask[order] = ramp
    mask = flat_mask.reshape(tensor.shape)

    # Same flag semantics as `sample`: smooth keeps the soft mask,
    # non-smooth draws a hard 0/1 mask from it.
    if not smooth:
        mask = torch.bernoulli(mask)

    if rescale:
        res = rescale_sum(tensor, mask)
    else:
        res = tensor * mask

    return res.to(tensor.dtype)


def sample(
    tensor: torch.Tensor, density: float, rescale: bool, smooth: bool
) -> torch.Tensor:
    """Sparsify `tensor` using its own normalised magnitudes as the mask.

    The mask is `(|tensor| ** power) / max(|tensor| ** power)`. `power` starts
    at 1 and is lowered step by step while the mean of the mask is below
    `density`, for at most 15 steps.

    Args:
        tensor: Values to sparsify.
        density: Target mean of the mask. Values >= 1 return `tensor` unchanged.
        rescale: If true, the result is produced by `rescale_sum(tensor, mask)`
            (defined elsewhere in this module); otherwise it is `tensor * mask`.
        smooth: If true, the continuous mask is applied directly; otherwise
            each mask value is used as the probability of a Bernoulli draw,
            giving a 0/1 mask.

    Returns:
        The masked tensor in the dtype of `tensor`, or `tensor` itself when it
        is empty, all-zero, contains a non-finite peak, or is already at least
        as sparse as `density`.
    """
    if density >= 1 or tensor.numel() == 0:
        return tensor

    magnitudes = tensor.abs()
    peak = magnitudes.max()
    # A zero peak leaves nothing to normalise by; an infinite or NaN peak
    # would poison every mask value.
    if peak == 0.0 or not torch.isfinite(peak):
        return tensor

    # Already at least as sparse as requested (in line with trimming).
    # NOTE: `x ** 0.0` is 1 even for x == 0, so it cannot be used to count
    # non-zero entries; count them explicitly instead.
    if torch.count_nonzero(tensor).item() <= density * tensor.numel():
        return tensor

    if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16:
        work_dtype = tensor.dtype
    else:
        # torch.bernoulli not implemented for float16 on CPU, upcast to float32
        work_dtype = torch.float32

    # Find the power that makes the distribution fit the density.
    # NOTE(review): the search is one-sided. It only runs while the mask mean
    # is at or below the density and stops at the first overshoot, so an input
    # whose initial mean already exceeds the density keeps power == 1.
    steps = 0
    power = 1.0
    avg = magnitudes.mean() / peak
    while (avg - density) <= 1e-5 and steps < 15:
        powered = magnitudes**power
        avg = powered.mean() / powered.max()
        power += avg - density
        if power < 0:
            # A negative exponent would blow up small magnitudes.
            power = 0
        steps += 1

    powered = magnitudes**power
    mask = (powered / powered.max()).to(work_dtype)
    if not smooth:
        mask = torch.bernoulli(mask)

    if rescale:
        res = rescale_sum(tensor, mask)
    else:
        res = tensor * mask
    return res.to(tensor.dtype)


def sparsify(
    tensor: torch.Tensor,
    density: float,
    method: SparsificationMethod,
    rescale: bool = False,
    smooth: bool = False,
) -> torch.Tensor:
    """Sparsify `tensor` with the strategy selected by `method`.

    Args:
        tensor: Values to sparsify.
        density: Density target forwarded to the selected strategy.
        method: Which strategy to run.
        rescale: Forwarded to every strategy. Defaults to False so callers
            that predate this flag keep working.
        smooth: Forwarded to the `sample` and `ranked` strategies only; the
            `magnitude` and `random` strategies do not take it.

    Returns:
        Whatever the selected strategy returns.

    Raises:
        NotImplementedError: If `method` is not one of the handled members.
    """
    if method == SparsificationMethod.magnitude:
        return magnitude(tensor, density=density, rescale=rescale)
    elif method == SparsificationMethod.random:
        return bernoulli(tensor, density=density, rescale=rescale)
    elif method == SparsificationMethod.sample:
        return sample(tensor, density=density, rescale=rescale, smooth=smooth)
    elif method == SparsificationMethod.ranked:
        return ranked(tensor, density=density, rescale=rescale, smooth=smooth)
    else:
        raise NotImplementedError(method)