23 changes: 10 additions & 13 deletions experiments/run_benchmark.py
@@ -11,12 +11,8 @@
from slicegpt import data_utils, gpu_utils, hf_utils, utils
from slicegpt.config import config

utils.configure_logging()

os.environ["WANDB__SERVICE_WAIT"] = "300"


def argparser() -> argparse.Namespace:
def benchmarking_arg_parser(interactive: bool = True) -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
@@ -74,8 +70,10 @@ def argparser() -> argparse.Namespace:
help="PyTorch device to use. Example values are 'cpu', 'cuda', 'cuda:0'. If not specified it will be defaulted to 'cuda' if available and 'cpu' otherwise.",
)

args = parser.parse_args()
return parser.parse_args() if interactive else parser.parse_args('')


def process_benchmarking_args(args: argparse.Namespace):
logging.debug(f'Parsed arguments:')
for arg, argv in vars(args).items():
logging.debug(f'{arg} = {argv}')
@@ -93,14 +91,9 @@ def argparser() -> argparse.Namespace:
else:
raise argparse.ArgumentTypeError(f"Data type should be one of 'fp16', 'fp32'")

return args


def main() -> None:
def benchmarking_main(args: argparse.Namespace) -> None:
logging.info("Running benchmarking of a sliced model.")

args = argparser()

logging.info(f"PyTorch device: {config.device}")
logging.info(f"Number of available cuda devices: {torch.cuda.device_count()}")

@@ -148,4 +141,8 @@ def main() -> None:


if __name__ == "__main__":
main()
utils.configure_logging()
os.environ["WANDB__SERVICE_WAIT"] = "300"
benchmarking_args = benchmarking_arg_parser()
process_benchmarking_args(benchmarking_args)
benchmarking_main(benchmarking_args)
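Taken together, the run_benchmark.py changes split the old `argparser()`/`main()` pair into three importable pieces — `benchmarking_arg_parser()`, `process_benchmarking_args()`, and `benchmarking_main()` — and move the logging and W&B side effects under the `__main__` guard. A minimal sketch of how another script could now drive the benchmark programmatically; the import path and the model override are assumptions, not part of this PR:

```python
# Sketch: calling the refactored entry points from other code.
# Assumes experiments/ is on the import path; the model name is a
# hypothetical override.
from run_benchmark import (
    benchmarking_arg_parser,
    benchmarking_main,
    process_benchmarking_args,
)

# interactive=False parses an empty argv, so the defaults are used
# and the caller's own command line is left untouched.
args = benchmarking_arg_parser(interactive=False)
args.model = "facebook/opt-125m"  # hypothetical override
process_benchmarking_args(args)   # logs and validates the parsed args
benchmarking_main(args)
```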
51 changes: 34 additions & 17 deletions experiments/run_finetuning.py
@@ -4,6 +4,8 @@
import argparse
import logging
import os
import pathlib
import shutil

import syne_tune
import torch
@@ -18,10 +20,6 @@
from slicegpt import data_utils, gpu_utils, hf_utils, utils
from slicegpt.config import config

utils.configure_logging()

os.environ["WANDB__SERVICE_WAIT"] = "300"


def get_optimizer_and_scheduler(model, train_dataset, config):
optimizer = torch.optim.AdamW(
@@ -64,7 +62,7 @@ def get_eval_dataloader(self, _) -> DataLoader:
return self.test_loader


def argparser():
def finetuning_arg_parser(interactive: bool = True) -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
@@ -195,9 +193,10 @@ def argparser():
help="target module option to apply lora to (names of attn i/p, attn o/p and mlp in LayerAdapter)",
)

args = parser.parse_args()
return parser.parse_args() if interactive else parser.parse_args('')


logging.debug(f'Parsed arguments:')
def process_finetuning_args(args):
for arg, argv in vars(args).items():
logging.debug(f'{arg} = {argv}')

@@ -214,14 +213,9 @@ def argparser():
else:
raise argparse.ArgumentTypeError(f"Data type should be one of 'fp16', 'fp32'")

return args


def main() -> None:
def finetuning_main(args: argparse.Namespace) -> None:
logging.info("Running SliceGPT post-slicing finetuning experiment")

args = argparser()

logging.info(f"PyTorch device: {config.device}")
logging.info(f"Number of available cuda devices: {torch.cuda.device_count()}")

@@ -342,15 +336,33 @@ def main() -> None:
trainer.train()

if args.save_dir:
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
rft_dir = args.save_dir
if not os.path.exists(rft_dir):
os.makedirs(rft_dir, exist_ok=True)

model_file = os.path.join(args.save_dir, os.path.basename(args.model) + "_" + str(args.sparsity) + ".pt")
model_file = os.path.join(rft_dir, os.path.basename(args.model) + "_" + str(args.sparsity) + ".pt")

# save peft model as a standard pt model
merged_model = lora_model.merge_and_unload()

torch.save(merged_model.state_dict(), model_file)

if args.sliced_model_path:
sliced_model_dir = os.path.dirname(args.sliced_model_path)
try:
# copy all config files (tokenizer, model and slicing configs)
for file in pathlib.Path(sliced_model_dir).glob("*.json"):
if 'safetensors' not in str(file):
shutil.copy(str(file), rft_dir)
# copy all tokenizer models
for file in pathlib.Path(sliced_model_dir).glob("*token*.model"):
shutil.copy(str(file), rft_dir)
# copy vocab merges if any
for file in pathlib.Path(sliced_model_dir).glob("merges.txt"):
shutil.copy(str(file), rft_dir)
except OSError as e:
logging.info(f'Failed to copy configs and tokenizer files: {e}')

logging.info(f"Saved sliced and finetuned model to {args.save_dir}")

utils.cleanup_memory()
@@ -365,4 +377,9 @@ def main() -> None:


if __name__ == "__main__":
main()
utils.configure_logging(log_to_console=True, log_to_file=False, level=logging.INFO)
os.environ["WANDB__SERVICE_WAIT"] = "300"

finetuning_args = finetuning_arg_parser()
process_finetuning_args(finetuning_args)
finetuning_main(finetuning_args)
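The save path in this file merges the LoRA adapters back into the base weights before serializing, so the checkpoint loads as an ordinary state dict with no peft dependency. A minimal sketch of that merge-and-save pattern in isolation; the model name and LoRA hyperparameters are placeholders, not the script's actual configuration:

```python
# Sketch of the merge-and-save pattern used above (peft + transformers).
# Model name and LoRA hyperparameters are placeholders.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
lora_model = get_peft_model(base_model, lora_config)

# ... finetuning would happen here ...

# Fold the low-rank updates into the base weights and drop the peft
# wrappers, leaving a plain nn.Module whose state dict stands alone.
merged_model = lora_model.merge_and_unload()
torch.save(merged_model.state_dict(), "opt-125m_0.25.pt")
```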
37 changes: 19 additions & 18 deletions experiments/run_slicegpt_perplexity.py
@@ -14,12 +14,8 @@
from slicegpt.config import config
from slicegpt.slicing_scheduler import ConstSlicingScheduler

utils.configure_logging()

os.environ["WANDB__SERVICE_WAIT"] = "300"


def argparser() -> argparse.Namespace:
def slicing_arg_parser(interactive: bool = True) -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
@@ -105,9 +101,10 @@ def argparser() -> argparse.Namespace:
help="PyTorch device to use. Example values are 'cpu', 'cuda', 'cuda:0'. If not specified it will be defaulted to 'cuda' if available and 'cpu' otherwise.",
)

args = parser.parse_args()
return parser.parse_args() if interactive else parser.parse_args('')

logging.debug(f'Parsed arguments:')

def process_slicing_args(args):
for arg, argv in vars(args).items():
logging.debug(f'{arg} = {argv}')

@@ -124,14 +121,9 @@ def argparser() -> argparse.Namespace:
else:
raise argparse.ArgumentTypeError(f"Data type should be one of 'fp16', 'fp32'")

return args


def main() -> None:
def slicing_main(args: argparse.Namespace) -> None:
logging.info("Running SliceGPT perplexity experiment")

args = argparser()

logging.info(f"PyTorch device: {config.device}")
logging.info(f"Number of available cuda devices: {torch.cuda.device_count()}")

@@ -248,11 +240,15 @@ def reset_model_device() -> None:
# If slicing a local model, also save HF config files in sliced model dir
if args.model_path:
try:
# copy all config files
for file in pathlib.Path(args.model_path).glob("*config*.json"):
# copy all config files (tokenizer, model and slicing configs)
for file in pathlib.Path(args.model_path).glob("*.json"):
if 'safetensors' not in str(file):
shutil.copy(str(file), sliced_model_dir)
# copy all tokenizer models
for file in pathlib.Path(args.model_path).glob("*token*.model"):
shutil.copy(str(file), sliced_model_dir)
# copy all tokenizer files
for file in pathlib.Path(args.model_path).glob("*token*.json"):
# copy vocab merges if any
for file in pathlib.Path(args.model_path).glob("merges.txt"):
shutil.copy(str(file), sliced_model_dir)
except OSError as e:
logging.info(f'Failed to copy configs and tokenizer files: {e}')
@@ -270,4 +266,9 @@ def reset_model_device() -> None:


if __name__ == "__main__":
main()
utils.configure_logging(log_to_console=True, log_to_file=False, level=logging.INFO)
os.environ["WANDB__SERVICE_WAIT"] = "300"

slicing_args = slicing_arg_parser()
process_slicing_args(slicing_args)
slicing_main(slicing_args)
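The config- and tokenizer-copying block above is now nearly identical to the one added in run_finetuning.py. A hypothetical shared helper (not part of this PR) that both call sites could import instead of duplicating the loops:

```python
# Hypothetical helper consolidating the duplicated copy logic.
import logging
import pathlib
import shutil


def copy_hf_assets(src_dir: str, dst_dir: str) -> None:
    """Copy model/tokenizer/slicing configs, tokenizer models, and
    vocab merges from src_dir to dst_dir, skipping safetensors indexes."""
    try:
        for pattern in ("*.json", "*token*.model", "merges.txt"):
            for file in pathlib.Path(src_dir).glob(pattern):
                if 'safetensors' in file.name:
                    continue
                shutil.copy(str(file), dst_dir)
    except OSError as e:
        logging.info(f'Failed to copy configs and tokenizer files: {e}')
```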
84 changes: 44 additions & 40 deletions experiments/run_zero_shot_tasks.py
@@ -18,12 +18,9 @@
from slicegpt import gpu_utils, hf_utils, utils
from slicegpt.config import config

utils.configure_logging()

os.environ["WANDB__SERVICE_WAIT"] = "300"


def parse_args() -> argparse.Namespace:
def eval_arg_parser(interactive: bool = True) -> argparse.Namespace:
initialize_tasks()
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
@@ -69,14 +66,39 @@ def parse_args() -> argparse.Namespace:
choices=lm_eval_utils.MultiChoice(tasks.ALL_TASKS),
)
parser.add_argument('--num-fewshot', type=int, default=0, help="Number of fewshots for all tasks.")
return parser.parse_args()
parser.add_argument("--save-dir", type=str, default=".", help="Path to save the lm eval results")
return parser.parse_args() if interactive else parser.parse_args('')


def main() -> None:
logging.info("Running SliceGPT zeroshot tasks experiment.")
def calculate_avg_accuracy(task_names: str, results: dict) -> float:
n_tasks = len(task_names)
acc_cumul = sum(
result.get('acc_norm,none', result['acc,none']) for task, result in results.items() if 'mmlu' not in task
)

initialize_tasks()
args = parse_args()
questions_per_mmlu_task = {
task_name: lm_eval.tasks.get_task_dict([task_name])[task_name].dataset["test"].num_rows
for task_name in task_names
if 'mmlu' in task_name
}

if not questions_per_mmlu_task:
return acc_cumul / n_tasks

# Calculate average accuracy for mmlu tasks, weighted by number of questions in each task
acc_mmlu = sum(
result.get('acc_norm,none', result['acc,none']) * questions_per_mmlu_task[task]
for task, result in results.items()
if 'mmlu' in task
)
acc_mmlu_avg = acc_mmlu / sum(questions_per_mmlu_task.values())
wandb.log({'acc_mmlu_avg': acc_mmlu_avg})

return (acc_cumul + acc_mmlu_avg) / (n_tasks - len(questions_per_mmlu_task) + 1)


def eval_main(args: argparse.Namespace) -> None:
logging.info("Running SliceGPT zeroshot tasks experiment.")

logging.info(f"PyTorch device: {config.device}")
logging.info(f"Number of available cuda devices: {torch.cuda.device_count()}")
@@ -127,40 +149,22 @@ def main() -> None:
'results'
]

wandb.log(results)
metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) for task, result in results.items()}
logging.info(json.dumps(metric_vals, indent=4))

def calculate_avg_accuracy(task_names, results):
n_tasks = len(task_names)
acc_cumul = sum(
result.get('acc_norm,none', result['acc,none']) for task, result in results.items() if 'mmlu' not in task
)

questions_per_mmlu_task = {
task_name: lm_eval.tasks.get_task_dict([task_name])[task_name].dataset["test"].num_rows
for task_name in task_names
if 'mmlu' in task_name
}

if not questions_per_mmlu_task:
return acc_cumul / n_tasks

# Calculate average accuracy for mmlu tasks, weighted by number of questions in each task
acc_mmlu = sum(
result.get('acc_norm,none', result['acc,none']) * questions_per_mmlu_task[task]
for task, result in results.items()
if 'mmlu' in task
)
acc_mmlu_avg = acc_mmlu / sum(questions_per_mmlu_task.values())
wandb.log({'acc_mmlu_avg': acc_mmlu_avg})

return (acc_cumul + acc_mmlu_avg) / (n_tasks - len(questions_per_mmlu_task) + 1)

acc_avg = calculate_avg_accuracy(task_names, results)
metric_vals['average'] = round(acc_avg, 4)
with open(f"{args.save_dir}/{args.num_fewshot}_shot_task_results.json", "w") as f:
json.dump(metric_vals, f)

wandb.log(results)
wandb.log({'acc_avg': acc_avg})

logging.info(json.dumps(metric_vals, indent=4))
logging.info(f"Average accuracy across tasks: {acc_avg}")


if __name__ == "__main__":
main()
utils.configure_logging(log_to_console=True, log_to_file=False, level=logging.ERROR)
os.environ["WANDB__SERVICE_WAIT"] = "300"

eval_args = eval_arg_parser()
eval_main(eval_args)
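The hoisted `calculate_avg_accuracy()` collapses all MMLU subtasks into one question-weighted entry before averaging, so a run over m non-MMLU tasks and k MMLU subtasks divides by m + 1 rather than m + k. A small worked example with made-up accuracies and question counts:

```python
# Worked example of the averaging scheme above (all numbers made up).
results = {
    'arc_easy': {'acc,none': 0.70},
    'hellaswag': {'acc_norm,none': 0.60},
    'mmlu_anatomy': {'acc,none': 0.50},    # assume 100 test questions
    'mmlu_astronomy': {'acc,none': 0.30},  # assume 300 test questions
}
questions = {'mmlu_anatomy': 100, 'mmlu_astronomy': 300}

non_mmlu = [r.get('acc_norm,none', r['acc,none'])
            for task, r in results.items() if 'mmlu' not in task]
acc_mmlu_avg = sum(
    results[t]['acc,none'] * n for t, n in questions.items()
) / sum(questions.values())  # (0.5*100 + 0.3*300) / 400 = 0.35

acc_avg = (sum(non_mmlu) + acc_mmlu_avg) / (len(non_mmlu) + 1)
print(round(acc_avg, 4))  # (0.70 + 0.60 + 0.35) / 3 = 0.55
```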