From 9e55657cbfd4d3ba86c91d2714ffbac680fb9832 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 00:14:00 +0000 Subject: [PATCH 01/21] Initial commit --- mergekit/sparsify.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 01239cd3..b688d80e 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -23,6 +23,8 @@ class SparsificationMethod(str, Enum): random = "random" rescaled_random = "rescaled_random" +def sample(): + pass def magnitude(tensor: torch.Tensor, density: float) -> torch.Tensor: """Masks out the smallest values, retaining a proportion of `density`.""" From 5709da69ea698b85f40b0a733dad5497880aa138 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 00:44:15 +0000 Subject: [PATCH 02/21] Initial implementation --- mergekit/sparsify.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index b688d80e..33357989 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -23,8 +23,28 @@ class SparsificationMethod(str, Enum): random = "random" rescaled_random = "rescaled_random" -def sample(): - pass +def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: + """Samples the tensor as it's own mask, then shifts mean to fit density.""" + if density >= 1 or tensor.abs().max() == 0: + return tensor + + if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: + # torch.bernoulli not implemented for float16 on CPU, upcast to float32 + origin_type = tensor.dtype + tensor = tensor.to(torch.float32) + + avg = tensor.abs().mean() / tensor.abs().max() + + power = 1.0 + while abs(avg - density) > 1e-5: + power += avg - density + intermediate = tensor.abs()**power + avg = (intermediate.mean() / intermediate.max()) + + mask = torch.bernoulli(intermediate / intermediate.max()) + + tensor *= mask + return tensor.to(origin_type) def magnitude(tensor: torch.Tensor, density: float) -> torch.Tensor: """Masks out the smallest values, retaining a proportion of `density`.""" From 2e87e87a7c4957200a380c88cdd8dc59902caad7 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 02:54:04 +0000 Subject: [PATCH 03/21] Semi-complete implementation --- examples/ties.yml | 30 ++++++--------- mergekit/merge_methods/__init__.py | 6 +++ mergekit/sparsify.py | 62 +++++++++++++++++++----------- 3 files changed, 57 insertions(+), 41 deletions(-) diff --git a/examples/ties.yml b/examples/ties.yml index 8c5cfe5c..99174d4a 100644 --- a/examples/ties.yml +++ b/examples/ties.yml @@ -1,22 +1,16 @@ models: - - model: psmathur/orca_mini_v3_13b + - model: abacaj/phi-2-super # Best parameters: - density: [1, 0.7, 0.1] # density gradient - weight: 1.0 - - model: garage-bAInd/Platypus2-13B + density: [0.8, 0.3, 0.1] + weight: [1.0, 0.8, 0.5] + - model: rhysjones/phi-2-orange-v2 #Middle parameters: - density: 0.5 - weight: [0, 0.3, 0.7, 1] # weight gradient - - model: WizardLM/WizardMath-13B-V1.0 + density: [0.1, 0.4, 0.1] + weight: [0.2, 0.8, 0.1] + - model: mobiuslabsgmbh/aanaphi2-v0.1 # 2nd End parameters: - density: 0.33 - weight: - - filter: mlp - value: 0.5 - - value: 0 -merge_method: ties -base_model: TheBloke/Llama-2-13B-fp16 -parameters: - normalize: true - int8_mask: true -dtype: float16 + density: [0.1, 0.2, 0.6] + weight: [0.5, 0.8, 1.0] +merge_method: sample_ties +base_model: microsoft/phi-2 +dtype: bfloat16 \ No newline at end of file diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index 51589526..c7b650c2 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -59,6 +59,12 @@ def get(method: str) -> MergeMethod: ) elif method == "model_stock": return ModelStockMerge() + elif method == "sample_ties": + return GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.sample, + default_normalize=False, + ) raise RuntimeError(f"Unimplemented merge method {method}") diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 33357989..a0583869 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -22,29 +22,7 @@ class SparsificationMethod(str, Enum): magnitude = "magnitude" random = "random" rescaled_random = "rescaled_random" - -def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: - """Samples the tensor as it's own mask, then shifts mean to fit density.""" - if density >= 1 or tensor.abs().max() == 0: - return tensor - - if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: - # torch.bernoulli not implemented for float16 on CPU, upcast to float32 - origin_type = tensor.dtype - tensor = tensor.to(torch.float32) - - avg = tensor.abs().mean() / tensor.abs().max() - - power = 1.0 - while abs(avg - density) > 1e-5: - power += avg - density - intermediate = tensor.abs()**power - avg = (intermediate.mean() / intermediate.max()) - - mask = torch.bernoulli(intermediate / intermediate.max()) - - tensor *= mask - return tensor.to(origin_type) + sample = "sample" def magnitude(tensor: torch.Tensor, density: float) -> torch.Tensor: """Masks out the smallest values, retaining a proportion of `density`.""" @@ -84,6 +62,42 @@ def bernoulli( res /= density return res.to(tensor.dtype) +def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: + """Samples the tensor as it's own mask, then shifts mean to fit density.""" + if density >= 1 or tensor.abs().max() == 0.0: + return tensor + # print("Original tensor: ", tensor) + if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: + # torch.bernoulli not implemented for float16 on CPU, upcast to float32 + origin_type = tensor.dtype + tensor = tensor.to(torch.float32) + + intermediate = tensor.abs()[tensor.nonzero(as_tuple=True)] + avg = (intermediate.mean() / intermediate.max()).item() + + i = 0 + power = 1.0 + while abs(avg - density) > 2e-4 and i < 15: + if torch.numel(intermediate) < 5: + break + # print("Average: ", avg) + # print("Density: ", density) + # print("Diff: ", avg - density) + power += avg - density + # print("Power: ", power) + intermediate = tensor.abs()[tensor.nonzero(as_tuple=True)]**power + # print("Intermediate tensor: ", intermediate) + avg = (intermediate.mean() / intermediate.max()).item() + i += 1 + + intermediate = tensor.abs()**power + mask = torch.bernoulli(intermediate / intermediate.max()) + # print("Mask: ", mask) + + tensor *= mask + if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: + return tensor.to(origin_type) + return tensor def sparsify( tensor: torch.Tensor, density: float, method: SparsificationMethod @@ -94,5 +108,7 @@ def sparsify( return bernoulli(tensor, density=density, rescale=False) elif method == SparsificationMethod.rescaled_random: return bernoulli(tensor, density=density, rescale=True) + elif method == SparsificationMethod.sample: + return sample(tensor, density=density) else: raise NotImplementedError(method) From 9f063538f35ff68bef6fcc3df582ea7738955cc1 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 03:15:44 +0000 Subject: [PATCH 04/21] Fix negative power --- mergekit/sparsify.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index a0583869..15248532 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -80,19 +80,22 @@ def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: while abs(avg - density) > 2e-4 and i < 15: if torch.numel(intermediate) < 5: break - # print("Average: ", avg) - # print("Density: ", density) - # print("Diff: ", avg - density) + print("Average: ", avg) + print("Density: ", density) + print("Diff: ", avg - density) power += avg - density - # print("Power: ", power) + print("Power: ", power) intermediate = tensor.abs()[tensor.nonzero(as_tuple=True)]**power - # print("Intermediate tensor: ", intermediate) + print("Intermediate tensor: ", intermediate) avg = (intermediate.mean() / intermediate.max()).item() i += 1 + if power < 0: + power = 0 + break intermediate = tensor.abs()**power + mask = torch.bernoulli(intermediate / intermediate.max()) - # print("Mask: ", mask) tensor *= mask if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: From bb1c74fb33c91d303f912cca96e86467593a5275 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 04:22:12 +0000 Subject: [PATCH 05/21] better handling of 0 powers (?) --- mergekit/sparsify.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 15248532..532cee9e 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -67,31 +67,28 @@ def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: if density >= 1 or tensor.abs().max() == 0.0: return tensor # print("Original tensor: ", tensor) + origin_type = tensor.dtype if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: # torch.bernoulli not implemented for float16 on CPU, upcast to float32 - origin_type = tensor.dtype tensor = tensor.to(torch.float32) - intermediate = tensor.abs()[tensor.nonzero(as_tuple=True)] + intermediate = tensor.abs() avg = (intermediate.mean() / intermediate.max()).item() i = 0 power = 1.0 while abs(avg - density) > 2e-4 and i < 15: - if torch.numel(intermediate) < 5: - break - print("Average: ", avg) - print("Density: ", density) - print("Diff: ", avg - density) + # print("Average: ", avg) + # print("Density: ", density) + # print("Diff: ", avg - density) power += avg - density - print("Power: ", power) - intermediate = tensor.abs()[tensor.nonzero(as_tuple=True)]**power - print("Intermediate tensor: ", intermediate) + # print("Power: ", power) + intermediate = tensor.abs()**power + # print("Intermediate tensor: ", intermediate) avg = (intermediate.mean() / intermediate.max()).item() i += 1 if power < 0: - power = 0 - break + return tensor.to(origin_type) intermediate = tensor.abs()**power From 848d72dcd8ace9f8c4bad5eed75c8c54a9ef11e3 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 05:07:27 +0000 Subject: [PATCH 06/21] Revert "better handling of 0 powers (?)" This reverts commit bb1c74fb33c91d303f912cca96e86467593a5275. --- mergekit/sparsify.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 532cee9e..15248532 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -67,28 +67,31 @@ def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: if density >= 1 or tensor.abs().max() == 0.0: return tensor # print("Original tensor: ", tensor) - origin_type = tensor.dtype if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: # torch.bernoulli not implemented for float16 on CPU, upcast to float32 + origin_type = tensor.dtype tensor = tensor.to(torch.float32) - intermediate = tensor.abs() + intermediate = tensor.abs()[tensor.nonzero(as_tuple=True)] avg = (intermediate.mean() / intermediate.max()).item() i = 0 power = 1.0 while abs(avg - density) > 2e-4 and i < 15: - # print("Average: ", avg) - # print("Density: ", density) - # print("Diff: ", avg - density) + if torch.numel(intermediate) < 5: + break + print("Average: ", avg) + print("Density: ", density) + print("Diff: ", avg - density) power += avg - density - # print("Power: ", power) - intermediate = tensor.abs()**power - # print("Intermediate tensor: ", intermediate) + print("Power: ", power) + intermediate = tensor.abs()[tensor.nonzero(as_tuple=True)]**power + print("Intermediate tensor: ", intermediate) avg = (intermediate.mean() / intermediate.max()).item() i += 1 if power < 0: - return tensor.to(origin_type) + power = 0 + break intermediate = tensor.abs()**power From 1a355125be58f38ffdffa4fb716cd460adb3f188 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 19:44:15 +0000 Subject: [PATCH 07/21] Fix handling of Sparse tensors --- mergekit/sparsify.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 15248532..719c4d0a 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -66,34 +66,35 @@ def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" if density >= 1 or tensor.abs().max() == 0.0: return tensor - # print("Original tensor: ", tensor) + if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: # torch.bernoulli not implemented for float16 on CPU, upcast to float32 origin_type = tensor.dtype tensor = tensor.to(torch.float32) - intermediate = tensor.abs()[tensor.nonzero(as_tuple=True)] - avg = (intermediate.mean() / intermediate.max()).item() - - i = 0 - power = 1.0 - while abs(avg - density) > 2e-4 and i < 15: - if torch.numel(intermediate) < 5: - break - print("Average: ", avg) - print("Density: ", density) - print("Diff: ", avg - density) + intermediate = tensor.abs() + avg = intermediate.mean() / intermediate.max() + + # Handle if the tensor is already sparser than the density (In line with trimming). + if ((intermediate**0.0).mean() / (intermediate**0.0).max()) <= density: + if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: + return tensor.to(origin_type) + return tensor + + # Find the power that makes the distribution fit the density + i = 0; power = 1.0 + while i < 15: + # print("Average: ", avg) + # print("Density: ", density) + # print("Diff: ", avg - density) power += avg - density - print("Power: ", power) - intermediate = tensor.abs()[tensor.nonzero(as_tuple=True)]**power - print("Intermediate tensor: ", intermediate) - avg = (intermediate.mean() / intermediate.max()).item() - i += 1 + # print("Power: ", power) if power < 0: power = 0 - break - - intermediate = tensor.abs()**power + intermediate = tensor.abs()**power + # print("Intermediate: ", intermediat) + avg = intermediate.mean() / intermediate.max() + i += 1 mask = torch.bernoulli(intermediate / intermediate.max()) From df86fdc3bbcc15f251e2fc44bf1cb63589bc8889 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 19:49:44 +0000 Subject: [PATCH 08/21] Add rescaling --- mergekit/sparsify.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 9663961c..74fab497 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -77,13 +77,14 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens res /= density return res.to(tensor.dtype) -def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: +def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" if density >= 1 or tensor.abs().max() == 0.0: return tensor if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: # torch.bernoulli not implemented for float16 on CPU, upcast to float32 + flag = True origin_type = tensor.dtype tensor = tensor.to(torch.float32) @@ -92,7 +93,7 @@ def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: # Handle if the tensor is already sparser than the density (In line with trimming). if ((intermediate**0.0).mean() / (intermediate**0.0).max()) <= density: - if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: + if flag: return tensor.to(origin_type) return tensor @@ -113,10 +114,14 @@ def sample(tensor: torch.Tensor, density: float) -> torch.Tensor: mask = torch.bernoulli(intermediate / intermediate.max()) - tensor *= mask - if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: - return tensor.to(origin_type) - return tensor + if rescale: + res = rescale_sum(tensor, mask) + else: + res = tensor * mask + + if flag: + return res.to(origin_type) + return res def sparsify( tensor: torch.Tensor, From b42ed87f1afdd8816cc16617bbf8b01ad970d692 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 19:56:42 +0000 Subject: [PATCH 09/21] Clean up code --- mergekit/merge_methods/__init__.py | 1 + mergekit/sparsify.py | 38 ++++++++++-------------------- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index 26adde30..24848d94 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -68,6 +68,7 @@ def get(method: str) -> MergeMethod: consensus_method=ConsensusMethod.sum, sparsification_method=SparsificationMethod.sample, default_normalize=False, + default_rescale=True, ) raise RuntimeError(f"Unimplemented merge method {method}") diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 74fab497..05ffd8f3 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -82,46 +82,32 @@ def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: if density >= 1 or tensor.abs().max() == 0.0: return tensor - if (tensor.device.type == "cpu") or tensor.dtype != torch.bfloat16: - # torch.bernoulli not implemented for float16 on CPU, upcast to float32 - flag = True - origin_type = tensor.dtype - tensor = tensor.to(torch.float32) - - intermediate = tensor.abs() - avg = intermediate.mean() / intermediate.max() - # Handle if the tensor is already sparser than the density (In line with trimming). - if ((intermediate**0.0).mean() / (intermediate**0.0).max()) <= density: - if flag: - return tensor.to(origin_type) + if ((tensor.abs()**0.0).mean() / (tensor.abs()**0.0).max()) <= density: return tensor + if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16: + work_dtype = tensor.dtype + else: + # torch.bernoulli not implemented for float16 on CPU, upcast to float32 + work_dtype = torch.float32 + # Find the power that makes the distribution fit the density i = 0; power = 1.0 while i < 15: - # print("Average: ", avg) - # print("Density: ", density) - # print("Diff: ", avg - density) - power += avg - density - # print("Power: ", power) - if power < 0: - power = 0 intermediate = tensor.abs()**power - # print("Intermediate: ", intermediat) avg = intermediate.mean() / intermediate.max() + power += avg - density + if power < 0: power = 0 i += 1 - mask = torch.bernoulli(intermediate / intermediate.max()) - + mask = torch.bernoulli((intermediate / intermediate.max()).to(work_dtype)) + if rescale: res = rescale_sum(tensor, mask) else: res = tensor * mask - - if flag: - return res.to(origin_type) - return res + return res.to(tensor.dtype) def sparsify( tensor: torch.Tensor, From 56f12d061554e327c72eace307477054a29ea608 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sun, 7 Apr 2024 20:34:17 +0000 Subject: [PATCH 10/21] Fixed edge cases --- mergekit/sparsify.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 05ffd8f3..19b0d912 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -79,7 +79,7 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" - if density >= 1 or tensor.abs().max() == 0.0: + if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float('inf'): return tensor # Handle if the tensor is already sparser than the density (In line with trimming). @@ -93,14 +93,16 @@ def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: work_dtype = torch.float32 # Find the power that makes the distribution fit the density - i = 0; power = 1.0 - while i < 15: + i = 0; power = 1.0; avg = tensor.abs().mean() / tensor.abs().max() + while (avg - density) <= 1e-5 and i < 15: intermediate = tensor.abs()**power avg = intermediate.mean() / intermediate.max() power += avg - density - if power < 0: power = 0 + if power < 0: + power = 0 i += 1 + intermediate = tensor.abs()**power mask = torch.bernoulli((intermediate / intermediate.max()).to(work_dtype)) if rescale: From e0e97d9adfb33ee4a6d5d0e54730d0661f2629b2 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Mon, 8 Apr 2024 17:44:14 +0000 Subject: [PATCH 11/21] reformatted --- mergekit/sparsify.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 19b0d912..e994a8a5 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -77,13 +77,14 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens res /= density return res.to(tensor.dtype) + def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float('inf'): return tensor # Handle if the tensor is already sparser than the density (In line with trimming). - if ((tensor.abs()**0.0).mean() / (tensor.abs()**0.0).max()) <= density: + if ((tensor.abs() ** 0.0).mean() / (tensor.abs() ** 0.0).max()) <= density: return tensor if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16: @@ -93,9 +94,11 @@ def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: work_dtype = torch.float32 # Find the power that makes the distribution fit the density - i = 0; power = 1.0; avg = tensor.abs().mean() / tensor.abs().max() + i = 0 + power = 1.0 + avg = tensor.abs().mean() / tensor.abs().max() while (avg - density) <= 1e-5 and i < 15: - intermediate = tensor.abs()**power + intermediate = tensor.abs() ** power avg = intermediate.mean() / intermediate.max() power += avg - density if power < 0: @@ -111,6 +114,7 @@ def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: res = tensor * mask return res.to(tensor.dtype) + def sparsify( tensor: torch.Tensor, density: float, From b71e908b9c7e3ce7d60df21fb1024db6c34ef890 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Mon, 8 Apr 2024 18:52:44 +0000 Subject: [PATCH 12/21] reformat --- mergekit/sparsify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index e994a8a5..4c145190 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -80,7 +80,7 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" - if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float('inf'): + if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float("inf"): return tensor # Handle if the tensor is already sparser than the density (In line with trimming). @@ -105,7 +105,7 @@ def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: power = 0 i += 1 - intermediate = tensor.abs()**power + intermediate = tensor.abs() ** power mask = torch.bernoulli((intermediate / intermediate.max()).to(work_dtype)) if rescale: From d99bd82765a286bebf699c3ee0939599c0eafa3b Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Mon, 8 Apr 2024 18:52:44 +0000 Subject: [PATCH 13/21] reformat --- examples/ties.yml | 2 +- mergekit/sparsify.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/ties.yml b/examples/ties.yml index 99174d4a..b5f63bce 100644 --- a/examples/ties.yml +++ b/examples/ties.yml @@ -13,4 +13,4 @@ models: weight: [0.5, 0.8, 1.0] merge_method: sample_ties base_model: microsoft/phi-2 -dtype: bfloat16 \ No newline at end of file +dtype: bfloat16 diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index e994a8a5..0d7dd711 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -23,6 +23,7 @@ class SparsificationMethod(str, Enum): random = "random" sample = "sample" + def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor): """Rescales the values to match the original tensor sum.""" org_sum = tensor.abs().sum() @@ -80,7 +81,7 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" - if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float('inf'): + if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float("inf"): return tensor # Handle if the tensor is already sparser than the density (In line with trimming). @@ -105,7 +106,7 @@ def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: power = 0 i += 1 - intermediate = tensor.abs()**power + intermediate = tensor.abs() ** power mask = torch.bernoulli((intermediate / intermediate.max()).to(work_dtype)) if rescale: From 192af12d81391559e94a9160b56e7f2749a350bd Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Tue, 9 Apr 2024 03:42:11 +0000 Subject: [PATCH 14/21] Basic Ranked implementation --- examples/ties.yml | 3 ++- mergekit/merge_methods/__init__.py | 7 ++++++ mergekit/sparsify.py | 38 ++++++++++++++++++++++++++++-- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/examples/ties.yml b/examples/ties.yml index b5f63bce..f9aaa7f4 100644 --- a/examples/ties.yml +++ b/examples/ties.yml @@ -11,6 +11,7 @@ models: parameters: density: [0.1, 0.2, 0.6] weight: [0.5, 0.8, 1.0] -merge_method: sample_ties +merge_method: ranked_ties base_model: microsoft/phi-2 +parameters: dtype: bfloat16 diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index 24848d94..e8509403 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -70,6 +70,13 @@ def get(method: str) -> MergeMethod: default_normalize=False, default_rescale=True, ) + elif method == "ranked_ties": + return GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.ranked, + default_normalize=False, + default_rescale=True, + ) raise RuntimeError(f"Unimplemented merge method {method}") diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 0d7dd711..4612404a 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -22,6 +22,7 @@ class SparsificationMethod(str, Enum): magnitude = "magnitude" random = "random" sample = "sample" + ranked = "ranked" def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor): @@ -49,8 +50,8 @@ def magnitude(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens w = tensor.abs().view(-1) if w.device.type == "cpu": w = w.float() - topk = torch.topk(w, k=k, largest=True) - mask.view(-1)[topk.indices] = 1 + topk = torch.argsort(w, descending=True)[:k] + mask.view(-1)[topk] = 1 if rescale: res = rescale_sum(tensor, mask) @@ -79,6 +80,37 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens return res.to(tensor.dtype) +def ranked(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: + print("triggered") + if density >= 1: + return tensor + + # Handle if the tensor is already sparser than the density (In line with trimming). + if ((tensor.abs() ** 0.0).mean() / (tensor.abs() ** 0.0).max()) <= density: + return tensor + + work_dtype = tensor.dtype + size = int(tensor.view(-1).shape[0]) + + rank = torch.zeros_like(tensor) + w = tensor.abs().view(-1) + if w.device.type == "cpu": + w = w.float() + sort = torch.argsort(w, descending=True) + rank.view(-1)[sort] = torch.linspace(1, 0, steps=size, device=w.device.type, dtype=work_dtype) + + mask = torch.bernoulli(rank) + + print(mask) + + if rescale: + res = rescale_sum(tensor, mask) + else: + res = tensor * mask + + return res + + def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float("inf"): @@ -128,5 +160,7 @@ def sparsify( return bernoulli(tensor, density=density, rescale=rescale) elif method == SparsificationMethod.sample: return sample(tensor, density=density, rescale=rescale) + elif method == SparsificationMethod.ranked: + return ranked(tensor, density=density, rescale=rescale) else: raise NotImplementedError(method) From 2c214ef37b44608b6b6b735cf576f297f46bb6d8 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Wed, 10 Apr 2024 03:33:37 +0000 Subject: [PATCH 15/21] Working implementation --- mergekit/sparsify.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 4612404a..23dee971 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -81,7 +81,6 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens def ranked(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: - print("triggered") if density >= 1: return tensor @@ -97,11 +96,9 @@ def ranked(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: if w.device.type == "cpu": w = w.float() sort = torch.argsort(w, descending=True) - rank.view(-1)[sort] = torch.linspace(1, 0, steps=size, device=w.device.type, dtype=work_dtype) - + + rank.view(-1)[sort] = torch.linspace(1, 0, steps=size, device=w.device.type, dtype=work_dtype).pow((1 / density) - 1) mask = torch.bernoulli(rank) - - print(mask) if rescale: res = rescale_sum(tensor, mask) From 480e9fdaf2fac7406de17550d07e013909f25ffe Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Wed, 10 Apr 2024 03:33:37 +0000 Subject: [PATCH 16/21] Working implementation --- mergekit/sparsify.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 4612404a..b183296d 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -81,7 +81,6 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens def ranked(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: - print("triggered") if density >= 1: return tensor @@ -92,16 +91,14 @@ def ranked(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: work_dtype = tensor.dtype size = int(tensor.view(-1).shape[0]) - rank = torch.zeros_like(tensor) + mask = torch.zeros_like(tensor) w = tensor.abs().view(-1) if w.device.type == "cpu": w = w.float() sort = torch.argsort(w, descending=True) - rank.view(-1)[sort] = torch.linspace(1, 0, steps=size, device=w.device.type, dtype=work_dtype) - - mask = torch.bernoulli(rank) - - print(mask) + + mask.view(-1)[sort] = torch.linspace(1, 0, steps=size, device=w.device.type, dtype=work_dtype).pow((1 / density) - 1) + mask = torch.bernoulli(mask) if rescale: res = rescale_sum(tensor, mask) From 3f6207e4e950775366c063f1d7aa626356c59e62 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sat, 13 Apr 2024 03:01:42 +0000 Subject: [PATCH 17/21] Added Smoothing --- examples/ties.yml | 28 +++++++++++-------- mergekit/merge_methods/__init__.py | 6 ++++ .../generalized_task_arithmetic.py | 7 +++++ mergekit/sparsify.py | 20 +++++++------ 4 files changed, 41 insertions(+), 20 deletions(-) diff --git a/examples/ties.yml b/examples/ties.yml index 7c8c373d..aa2f56e0 100644 --- a/examples/ties.yml +++ b/examples/ties.yml @@ -1,16 +1,20 @@ models: - - model: abacaj/phi-2-super # Best + - model: mistralai/Mistral-7B-v0.1 + # no parameters necessary for base model + - model: Undi95/Toppy-M-7B #175 parameters: - density: [0.8, 0.3, 0.1] - weight: [1.0, 0.8, 0.5] - - model: rhysjones/phi-2-orange-v2 #Middle + weight: 0.54 + density: 0.81 + - model: PistachioAlt/Noromaid-Bagel-7B-Slerp #75 parameters: - density: [0.1, 0.4, 0.1] - weight: [0.2, 0.8, 0.1] - - model: mobiuslabsgmbh/aanaphi2-v0.1 # 2nd End + weight: 0.23 + density: 0.61 + - model: OpenPipe/mistral-ft-optimized-1227 #100 parameters: - density: [0.1, 0.2, 0.6] - weight: [0.5, 0.8, 1.0] -merge_method: ranked_ties -base_model: microsoft/phi-2 -dtype: bfloat16 + weight: 0.31 + density: 0.68 +merge_method: dare_ties +base_model: mistralai/Mistral-7B-v0.1 +parameters: + int8_mask: true +dtype: bfloat16 \ No newline at end of file diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index e8509403..68b7648d 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -39,6 +39,7 @@ def get(method: str) -> MergeMethod: sparsification_method=None, default_normalize=False, default_rescale=False, + default_smooth=False, ) elif method == "ties": return GeneralizedTaskArithmeticMerge( @@ -46,6 +47,7 @@ def get(method: str) -> MergeMethod: sparsification_method=SparsificationMethod.magnitude, default_normalize=True, default_rescale=False, + default_smooth=False, ) elif method == "dare_ties": return GeneralizedTaskArithmeticMerge( @@ -53,6 +55,7 @@ def get(method: str) -> MergeMethod: sparsification_method=SparsificationMethod.random, default_normalize=False, default_rescale=True, + default_smooth=False, ) elif method == "dare_linear": return GeneralizedTaskArithmeticMerge( @@ -60,6 +63,7 @@ def get(method: str) -> MergeMethod: sparsification_method=SparsificationMethod.random, default_normalize=False, default_rescale=True, + default_smooth=False, ) elif method == "model_stock": return ModelStockMerge() @@ -69,6 +73,7 @@ def get(method: str) -> MergeMethod: sparsification_method=SparsificationMethod.sample, default_normalize=False, default_rescale=True, + default_smooth=False, ) elif method == "ranked_ties": return GeneralizedTaskArithmeticMerge( @@ -76,6 +81,7 @@ def get(method: str) -> MergeMethod: sparsification_method=SparsificationMethod.ranked, default_normalize=False, default_rescale=True, + default_smooth=False, ) raise RuntimeError(f"Unimplemented merge method {method}") diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 2bfbcd74..f614b869 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -39,6 +39,7 @@ class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True): sparsification_method: Optional[SparsificationMethod] default_normalize: bool default_rescale: bool + default_smooth: bool def parameters(self) -> List[ConfigParameterDef]: return [ @@ -49,6 +50,9 @@ def parameters(self) -> List[ConfigParameterDef]: ConfigParameterDef( name="rescale", required=False, default_value=self.default_rescale ), + ConfigParameterDef( + name="smooth", required=False, default_value=self.default_smooth + ), ] def tensor_parameters(self) -> List[ConfigParameterDef]: @@ -73,6 +77,7 @@ def make_task( int8_mask=parameters["int8_mask"], normalize=parameters["normalize"], rescale=parameters["rescale"], + smooth=parameters["smooth"], out_tensor_name=output_weight.name, ) @@ -86,6 +91,7 @@ class GTATask(Task[torch.Tensor]): int8_mask: bool normalize: bool rescale: bool + smooth: bool def uses_accelerator(self) -> bool: return True @@ -116,6 +122,7 @@ def execute( density=tv_info["density"], method=self.method.sparsification_method, rescale=self.rescale, + smooth=self.smooth, ) deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index b183296d..79ad80f2 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -80,7 +80,7 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens return res.to(tensor.dtype) -def ranked(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: +def ranked(tensor: torch.Tensor, density: float, rescale: bool, smooth: bool) -> torch.Tensor: if density >= 1: return tensor @@ -98,9 +98,10 @@ def ranked(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: sort = torch.argsort(w, descending=True) mask.view(-1)[sort] = torch.linspace(1, 0, steps=size, device=w.device.type, dtype=work_dtype).pow((1 / density) - 1) - mask = torch.bernoulli(mask) + if smooth: + mask = torch.bernoulli(mask) - if rescale: + if not rescale: res = rescale_sum(tensor, mask) else: res = tensor * mask @@ -108,7 +109,7 @@ def ranked(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: return res -def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: +def sample(tensor: torch.Tensor, density: float, rescale: bool, smooth: bool) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float("inf"): return tensor @@ -136,7 +137,9 @@ def sample(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: i += 1 intermediate = tensor.abs() ** power - mask = torch.bernoulli((intermediate / intermediate.max()).to(work_dtype)) + mask = (intermediate / intermediate.max()).to(work_dtype) + if not smooth: + mask = torch.bernoulli(mask) if rescale: res = rescale_sum(tensor, mask) @@ -149,15 +152,16 @@ def sparsify( tensor: torch.Tensor, density: float, method: SparsificationMethod, - rescale: bool = False, + rescale: bool, + smooth: bool, ) -> torch.Tensor: if method == SparsificationMethod.magnitude: return magnitude(tensor, density=density, rescale=rescale) elif method == SparsificationMethod.random: return bernoulli(tensor, density=density, rescale=rescale) elif method == SparsificationMethod.sample: - return sample(tensor, density=density, rescale=rescale) + return sample(tensor, density=density, rescale=rescale, smooth=smooth) elif method == SparsificationMethod.ranked: - return ranked(tensor, density=density, rescale=rescale) + return ranked(tensor, density=density, rescale=rescale, smooth=smooth) else: raise NotImplementedError(method) From bd6469b55ff4516c389e50db9458bd60f05c8a06 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sat, 13 Apr 2024 03:14:32 +0000 Subject: [PATCH 18/21] Undo some local stuff --- examples/ties.yml | 30 ++++++++++++++++-------------- mergekit/sparsify.py | 4 ++-- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/examples/ties.yml b/examples/ties.yml index aa2f56e0..f16b60ed 100644 --- a/examples/ties.yml +++ b/examples/ties.yml @@ -1,20 +1,22 @@ models: - - model: mistralai/Mistral-7B-v0.1 - # no parameters necessary for base model - - model: Undi95/Toppy-M-7B #175 + - model: psmathur/orca_mini_v3_13b parameters: - weight: 0.54 - density: 0.81 - - model: PistachioAlt/Noromaid-Bagel-7B-Slerp #75 + density: [1, 0.7, 0.1] # density gradient + weight: 1.0 + - model: garage-bAInd/Platypus2-13B parameters: - weight: 0.23 - density: 0.61 - - model: OpenPipe/mistral-ft-optimized-1227 #100 + density: 0.5 + weight: [0, 0.3, 0.7, 1] # weight gradient + - model: WizardLM/WizardMath-13B-V1.0 parameters: - weight: 0.31 - density: 0.68 -merge_method: dare_ties -base_model: mistralai/Mistral-7B-v0.1 + density: 0.33 + weight: + - filter: mlp + value: 0.5 + - value: 0 +merge_method: ties +base_model: TheBloke/Llama-2-13B-fp16 parameters: + normalize: true int8_mask: true -dtype: bfloat16 \ No newline at end of file +dtype: float16 \ No newline at end of file diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 79ad80f2..c4d9ce7b 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -50,8 +50,8 @@ def magnitude(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens w = tensor.abs().view(-1) if w.device.type == "cpu": w = w.float() - topk = torch.argsort(w, descending=True)[:k] - mask.view(-1)[topk] = 1 + topk = torch.topk(w, k=k, largest=True) + mask.view(-1)[topk.indices] = 1 if rescale: res = rescale_sum(tensor, mask) From dc381aa97d42f910ef675e7ef48b1ab1a00157c0 Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sat, 13 Apr 2024 03:27:19 +0000 Subject: [PATCH 19/21] Reformatted --- examples/ties.yml | 2 +- mergekit/sparsify.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/ties.yml b/examples/ties.yml index f16b60ed..8c5cfe5c 100644 --- a/examples/ties.yml +++ b/examples/ties.yml @@ -19,4 +19,4 @@ base_model: TheBloke/Llama-2-13B-fp16 parameters: normalize: true int8_mask: true -dtype: float16 \ No newline at end of file +dtype: float16 diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index c4d9ce7b..81c57011 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -80,10 +80,12 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens return res.to(tensor.dtype) -def ranked(tensor: torch.Tensor, density: float, rescale: bool, smooth: bool) -> torch.Tensor: +def ranked( + tensor: torch.Tensor, density: float, rescale: bool, smooth: bool + ) -> torch.Tensor: if density >= 1: return tensor - + # Handle if the tensor is already sparser than the density (In line with trimming). if ((tensor.abs() ** 0.0).mean() / (tensor.abs() ** 0.0).max()) <= density: return tensor @@ -96,11 +98,13 @@ def ranked(tensor: torch.Tensor, density: float, rescale: bool, smooth: bool) -> if w.device.type == "cpu": w = w.float() sort = torch.argsort(w, descending=True) - - mask.view(-1)[sort] = torch.linspace(1, 0, steps=size, device=w.device.type, dtype=work_dtype).pow((1 / density) - 1) + + mask.view(-1)[sort] = torch.linspace( + 1, 0, steps=size, device=w.device.type, dtype=work_dtype + ).pow((1 / density) - 1) if smooth: mask = torch.bernoulli(mask) - + if not rescale: res = rescale_sum(tensor, mask) else: @@ -109,7 +113,9 @@ def ranked(tensor: torch.Tensor, density: float, rescale: bool, smooth: bool) -> return res -def sample(tensor: torch.Tensor, density: float, rescale: bool, smooth: bool) -> torch.Tensor: +def sample( + tensor: torch.Tensor, density: float, rescale: bool, smooth: bool + ) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float("inf"): return tensor From e90e2f6b67e2b6e8bdccee78f465e52e0c3522de Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sat, 13 Apr 2024 03:29:32 +0000 Subject: [PATCH 20/21] again :P --- mergekit/sparsify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 81c57011..e929f329 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -101,7 +101,7 @@ def ranked( mask.view(-1)[sort] = torch.linspace( 1, 0, steps=size, device=w.device.type, dtype=work_dtype - ).pow((1 / density) - 1) + ).pow((1 / density) - 1) if smooth: mask = torch.bernoulli(mask) @@ -115,7 +115,7 @@ def ranked( def sample( tensor: torch.Tensor, density: float, rescale: bool, smooth: bool - ) -> torch.Tensor: +) -> torch.Tensor: """Samples the tensor as it's own mask, then shifts mean to fit density.""" if density >= 1 or tensor.abs().max() == 0.0 or tensor.abs().max() == float("inf"): return tensor From 15b699da4799ddea28f347f7663f4bf2ae1cb3ed Mon Sep 17 00:00:00 2001 From: Azazelle Guice Date: Sat, 13 Apr 2024 03:30:31 +0000 Subject: [PATCH 21/21] again again :P --- mergekit/sparsify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index e929f329..6319408a 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -82,7 +82,7 @@ def bernoulli(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tens def ranked( tensor: torch.Tensor, density: float, rescale: bool, smooth: bool - ) -> torch.Tensor: +) -> torch.Tensor: if density >= 1: return tensor