diff --git a/examples/neuron-mapping/compare.py b/examples/neuron-mapping/compare.py
new file mode 100644
index 0000000..347386f
--- /dev/null
+++ b/examples/neuron-mapping/compare.py
@@ -0,0 +1,31 @@
+import torch
+
+def compare_pruned_ff_criteria(cripple_repos: list[str], model_size: str):
+    # cripple_repos = ["physics", "biology", "code"]
+    directory = "/home/ubuntu/taker-rashid/examples/neuron-mapping/saved_tensors/" + model_size + "/"
+    focus_repo = "pile"
+    suffix = "-" + model_size + "-recent.pt"
+    ratios = {}
+    ratios["model_size"] = model_size
+
+    for repo1 in cripple_repos:
+        # load ff_criteria from repo1
+        repo1_tensors = torch.load(directory + repo1 + "-" + focus_repo + suffix)
+        repo1_ff_criteria = repo1_tensors["ff_criteria"]
+        ratios[repo1] = {}
+        for repo2 in cripple_repos:
+            if repo1 == repo2:
+                continue
+            # load ff_criteria from repo2
+            repo2_tensors = torch.load(directory + repo2 + "-" + focus_repo + suffix)
+            repo2_ff_criteria = repo2_tensors["ff_criteria"]
+
+            # fraction of repo1's pruned neurons that repo2 also prunes
+            matches = torch.logical_and(repo1_ff_criteria, repo2_ff_criteria)
+            ratio = torch.sum(matches) / torch.sum(repo1_ff_criteria)
+            ratios[repo1][repo2] = ratio.item()  # plain float for readable printing
+
+    return ratios
+
+# model_size must match the folder name written by save_tensor_dict, e.g. "15M"
+print(compare_pruned_ff_criteria(["physics", "biology", "code"], "15M"))
\ No newline at end of file
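For intuition on the ratio above: it is |A and B| / |A| over two boolean pruning masks, so it is asymmetric, and ratios[a][b] and ratios[b][a] generally differ. A minimal sketch with toy masks (values hypothetical, not taken from any saved run):

    import torch

    # two hypothetical boolean pruning masks over the same 8 neurons
    a = torch.tensor([True, True, True, False, False, False, True, False])
    b = torch.tensor([True, False, True, True, False, False, False, False])

    matches = torch.logical_and(a, b)                   # neurons pruned by both
    print((torch.sum(matches) / torch.sum(a)).item())   # 2/4 = 0.5
    print((torch.sum(matches) / torch.sum(b)).item())   # 2/3 ~= 0.667, hence asymmetric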
diff --git a/examples/neuron-mapping/prune_repos.py b/examples/neuron-mapping/prune_repos.py
new file mode 100644
index 0000000..c32000b
--- /dev/null
+++ b/examples/neuron-mapping/prune_repos.py
@@ -0,0 +1,73 @@
+
+from taker.data_classes import PruningConfig
+from taker.parser import cli_parser
+from taker.prune import run_pruning
+import torch
+
+def compare_pruned_ff_criteria(cripple_repos: list[str], model_size: str):
+    # cripple_repos = ["physics", "biology", "code"]
+    print("model_size: ", model_size)
+    directory = "/home/ubuntu/taker-rashid/examples/neuron-mapping/saved_tensors/" + model_size + "/"
+    focus_repo = "pile"
+    suffix = "-" + model_size + "-recent.pt"
+    ratios = {}
+    ratios["model_size"] = model_size
+
+    for repo1 in cripple_repos:
+        # load ff_criteria from repo1
+        repo1_tensors = torch.load(directory + repo1 + "-" + focus_repo + suffix)
+        repo1_ff_criteria = repo1_tensors["ff_criteria"]
+        ratios[repo1] = {}
+        for repo2 in cripple_repos:
+            if repo1 == repo2:
+                continue
+            # load ff_criteria from repo2
+            repo2_tensors = torch.load(directory + repo2 + "-" + focus_repo + suffix)
+            repo2_ff_criteria = repo2_tensors["ff_criteria"]
+
+            matches = torch.logical_and(repo1_ff_criteria, repo2_ff_criteria)
+            ratio = torch.sum(matches) / torch.sum(repo1_ff_criteria)
+            ratios[repo1][repo2] = ratio.item()
+
+    return ratios
+
+
+# Configure initial model and tests
+c = PruningConfig(
+    wandb_project = "testing", # repo to push results to
+    model_repo = "nickypro/tinyllama-15M",
+    # "meta-llama/Llama-2-7b"
+    token_limit = 1000,  # trim the input to this max length
+    run_pre_test = True, # evaluate the unpruned model
+    eval_sample_size = 1e3,
+    collection_sample_size = 1e3,
+    # Removals parameters
+    ff_frac = 0.2,    # fraction of feed-forward neurons to prune
+    attn_frac = 0.00, # fraction of attention neurons to prune
+    focus = "pile",      # the "reference" dataset
+    cripple = "physics", # the "unlearned" dataset
+    additional_datasets = tuple(), # any extra datasets to evaluate on
+    recalculate_activations = False, # iterative vs non-iterative
+    n_steps = 1,
+)
+
+# Parse CLI for arguments
+# c, args = cli_parser(c)
+
+# list of repos to cripple
+cripple_repos = ["physics", "biology", "chemistry", "math", "code", "poems", "civil", "stories"]
+ff_frac_to_prune = [0.01]
+model_size = c.model_repo.split('-')[-1]
+
+# Run the iterated pruning for each cripple repo, for a range of ff_frac values
+shared_pruning_data = {}
+for ff_frac in ff_frac_to_prune:
+    c.ff_frac = ff_frac
+    for repo in cripple_repos:
+        c.cripple = repo
+        print("running iteration for", c.cripple, "vs", c.focus, "with ff_frac:", ff_frac)
+        with torch.no_grad():
+            model, history = run_pruning(c)
+    ratios = compare_pruned_ff_criteria(cripple_repos, model_size)
+    shared_pruning_data[ff_frac] = ratios
+print(shared_pruning_data)
\ No newline at end of file
diff --git a/src/taker/activations.py b/src/taker/activations.py
index 72d6393..7f2e713 100644
--- a/src/taker/activations.py
+++ b/src/taker/activations.py
@@ -625,8 +625,18 @@ def save_timestamped_tensor_dict( opt: Model,
         data: Dict[str, Tensor],
         name: str ):
     now = datetime.datetime.now().strftime( "%Y-%m-%d_%H:%M:%S" )
-    os.makedirs( f'tmp/{opt.model_size}', exist_ok=True )
-    filename = f'tmp/{opt.model_size}/{opt.model_size}-{name}-{now}.pt'
+    os.makedirs( f'saved_tensors/{opt.model_size}', exist_ok=True )
+    filename = f'saved_tensors/{opt.model_size}/{name}-{opt.model_size}-{now}.pt'
+    torch.save( data, filename )
+    print( f'Saved {filename} to {opt.model_size}' )
+    return filename
+
+def save_tensor_dict( opt: Model,
+        data: Dict[str, Tensor],
+        name: str ):
+    # Same as above, but a fixed "-recent" filename, overwritten each run for easy loading
+    os.makedirs( f'saved_tensors/{opt.model_size}', exist_ok=True )
+    filename = f'saved_tensors/{opt.model_size}/{name}-{opt.model_size}-recent.pt'
     torch.save( data, filename )
     print( f'Saved {filename} to {opt.model_size}' )
     return filename
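Note on the save format above: save_tensor_dict writes to a fixed "...-recent.pt" path and overwrites it on each run, so scripts like compare.py can reload the latest tensors without knowing a timestamp. A round-trip sketch, assuming a physics-vs-pile run has already been saved for a model whose model_size is "15M" (the path is illustrative):

    import torch

    # path format: saved_tensors/{model_size}/{name}-{model_size}-recent.pt
    path = "saved_tensors/15M/physics-pile-15M-recent.pt"
    tensors = torch.load(path)
    ff_criteria = tensors["ff_criteria"]  # boolean mask of pruned FF neurons
    print(ff_criteria.shape, int(ff_criteria.sum()))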
diff --git a/src/taker/prune.py b/src/taker/prune.py
index 2d680a1..d79803a 100644
--- a/src/taker/prune.py
+++ b/src/taker/prune.py
@@ -11,7 +11,7 @@ from .eval import evaluate_all
 from .scoring import score_indices_by, score_indices
 from .activations import get_midlayer_activations, get_top_frac, \
-    choose_attn_heads_by, save_timestamped_tensor_dict
+    choose_attn_heads_by, save_timestamped_tensor_dict, save_tensor_dict
 from .texts import prepare
 
 def prune_and_evaluate(
@@ -60,6 +60,7 @@ def prune_and_evaluate(
 
     # Prune the model using the activation data
     data = score_and_prune(opt, focus_out, cripple_out, c)
+    # score_and_prune returns a RunDataItem; its "deletions" data now includes ff_scores and ff_criteria
 
     # Evaluate the model
     with torch.no_grad():
@@ -73,7 +74,7 @@ def score_and_prune( opt: Model,
         focus_activations_data: ActivationOverview,
         cripple_activations_data: ActivationOverview,
         pruning_config: PruningConfig,
-        save=False,
+        save=True,
     ):
     # Get the top fraction FF activations and prune
     ff_frac, ff_eps = pruning_config.ff_frac, pruning_config.ff_eps
@@ -133,7 +134,9 @@ def score_and_prune( opt: Model,
         "attn_criteria": attn_criteria if do_attn else None,
     }
     if save:
-        save_timestamped_tensor_dict( opt, tensor_data, "activation_metrics" )
+        # timestamped copy as before, plus a "-recent" copy (overwritten each run) for easy loading
+        save_timestamped_tensor_dict( opt, tensor_data, pruning_config.cripple + "-" + pruning_config.focus )
+        save_tensor_dict( opt, tensor_data, pruning_config.cripple + "-" + pruning_config.focus )
 
     # Initialize the output dictionary
     data = RunDataItem()
 
@@ -143,6 +146,8 @@ def score_and_prune( opt: Model,
         "attn_threshold": attn_threshold if do_attn else 0,
         "ff_del": float( torch.sum(ff_criteria) ) if do_ff else 0,
         "attn_del": float( torch.sum(attn_criteria) ) if do_attn else 0,
+        "ff_scores": ff_scores.cpu().numpy(),
+        "ff_criteria": ff_criteria.cpu().numpy(),
     }})
 
     data.update({'deletions_per_layer': {
@@ -273,7 +278,7 @@ def run_pruning(c: PruningConfig):
         entity=c.wandb_entity,
         name=c.wandb_run_name,
     )
-    wandb.config.update(c.to_dict())
+    wandb.config.update(c.to_dict(), allow_val_change=True)
 
     # Evaluate model before removal of any neurons
     if c.run_pre_test:
@@ -310,7 +315,7 @@ def run_pruning(c: PruningConfig):
     print(history.history[-1])
     print(history.df.T)
     print(history.df.T.to_csv())
-
+    # print("masks: ", opt.masks["mlp_pre_out"])
     return opt, history
 
 ######################################################################################
diff --git a/src/taker/texts.py b/src/taker/texts.py
index aa8bae8..1088d36 100644
--- a/src/taker/texts.py
+++ b/src/taker/texts.py
@@ -291,17 +291,37 @@ def infer_dataset_config(dataset_name:str, dataset_subset:str=None):
             dataset_image_label_key = "coarse_label",
             dataset_filter=DatasetFilters.filter_veh2,
         ),
-        EvalConfig("bio",
-            dataset_repo = "camel-ai/biology",
-            dataset_text_key = "message_2",
-            dataset_has_test_split = False,
-        ),
         EvalConfig("emotion",
             dataset_repo = "dair-ai/emotion",
             dataset_type = "text-classification",
             dataset_text_key = "text",
             dataset_text_label_key = "label",
             dataset_has_test_split = True,
+        ),
+        EvalConfig("biology",
+            dataset_repo = "camel-ai/biology",
+            dataset_text_key = "message_2",
+            dataset_has_test_split = False,
+        ),
+        EvalConfig("physics",
+            dataset_repo = "camel-ai/physics",
+            dataset_text_key = "message_2",
+            dataset_has_test_split = False,
+        ),
+        EvalConfig("chemistry",
+            dataset_repo = "camel-ai/chemistry",
+            dataset_text_key = "message_2",
+            dataset_has_test_split = False,
+        ),
+        EvalConfig("math",
+            dataset_repo = "camel-ai/math",
+            dataset_text_key = "message_2",
+            dataset_has_test_split = False,
+        ),
+        EvalConfig("poems",
+            dataset_repo = "sadFaceEmoji/english-poems",
+            dataset_text_key = "poem",
+            dataset_has_test_split = False,
+        )
     ]
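With the entries above registered, the new camel-ai datasets resolve by short name wherever dataset configs are inferred. A sketch, assuming infer_dataset_config looks names up in this list and that EvalConfig exposes its fields as attributes:

    from taker.texts import infer_dataset_config

    # each new entry resolves by its short name
    for name in ["biology", "physics", "chemistry", "math", "poems"]:
        cfg = infer_dataset_config(name)
        print(name, cfg.dataset_repo, cfg.dataset_text_key)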