From cadb7e0ea046d4088a03d506bebaa73dac8d782f Mon Sep 17 00:00:00 2001
From: Sigrid Jin (ง'̀-'́)ง oO
Date: Thu, 26 Dec 2024 23:02:10 +0900
Subject: [PATCH 1/9] add train_pylate_contrastive.py

Refer to https://github.com/lightonai/pylate/blob/main/pylate/losses/contrastive.py
---
 examples/train_pylate_contrastive.py | 69 ++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)
 create mode 100644 examples/train_pylate_contrastive.py

diff --git a/examples/train_pylate_contrastive.py b/examples/train_pylate_contrastive.py
new file mode 100644
index 00000000..1d88d5d4
--- /dev/null
+++ b/examples/train_pylate_contrastive.py
@@ -0,0 +1,69 @@
+import torch
+from datasets import load_dataset
+from sentence_transformers import (
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
+)
+
+from pylate import evaluation, losses, models, utils
+
+# Add at the start of your train_pylate_contrastive.py
+import torch
+torch._inductor.config.fallback_random = True
+torch._inductor.config.triton.unique_kernel_names = True
+# Or completely disable torch compile
+# model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True)
+
+# Define model parameters for contrastive training
+model_name = "answerdotai/ModernBERT-large"  # Choose the pre-trained model you want to use as base
+batch_size = 64  # Larger batch size often improves results, but requires more memory
+
+num_train_epochs = 5  # Adjust based on your requirements
+# Set the run name for logging and output directory
+run_name = "contrastive-sigrid-241226"
+output_dir = f"output/{run_name}"
+
+# 1. Here we define our ColBERT model. If not a ColBERT model, will add a linear layer to the base encoder.
+model = models.ColBERT(model_name_or_path=model_name)
+
+# Load dataset
+dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
+# Split the dataset (this dataset does not have a validation set, so we split the training set)
+splits = dataset.train_test_split(test_size=0.01)
+train_dataset = splits["train"]
+eval_dataset = splits["test"]
+
+# Define the loss function
+train_loss = losses.Contrastive(model=model)
+
+# Initialize the evaluator
+dev_evaluator = evaluation.ColBERTTripletEvaluator(
+    anchors=eval_dataset["query"],
+    positives=eval_dataset["positive"],
+    negatives=eval_dataset["negative"],
+)
+
+# Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
+args = SentenceTransformerTrainingArguments(
+    output_dir=output_dir,
+    num_train_epochs=num_train_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    run_name=run_name,  # Will be used in W&B if `wandb` is installed
+    learning_rate=3e-6,
+)
+
+# Initialize the trainer for the contrastive training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
+    evaluator=dev_evaluator,
+    data_collator=utils.ColBERTCollator(model.tokenize),
+)
+# Start the training process
+trainer.train()

From d92e5e5eb049cb30e56ff09576598f272fab4fb9 Mon Sep 17 00:00:00 2001
From: Sigrid Jin (ง'̀-'́)ง oO
Date: Mon, 30 Dec 2024 20:24:09 +0900
Subject: [PATCH 2/9] Update examples/train_pylate_contrastive.py

Co-authored-by: Antoine Chaffin <38869395+NohTow@users.noreply.github.com>
---
 examples/train_pylate_contrastive.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/train_pylate_contrastive.py b/examples/train_pylate_contrastive.py
index 1d88d5d4..860b14f2 100644
--- a/examples/train_pylate_contrastive.py
+++ b/examples/train_pylate_contrastive.py
@@ -12,7 +12,8 @@
 torch._inductor.config.fallback_random = True
 torch._inductor.config.triton.unique_kernel_names = True
 # Or completely disable torch compile
-# model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True)
+# ModernBERT is compiled by default, so there is no need to call compile explicitly and it actually breaks the model
+# model = torch.compile(model)
 
 # Define model parameters for contrastive training
 model_name = "answerdotai/ModernBERT-large"  # Choose the pre-trained model you want to use as base

From 88ea5f3690c799e74cec4692fd67cc79c9743467 Mon Sep 17 00:00:00 2001
From: Sigrid Jin (ง'̀-'́)ง oO
Date: Mon, 30 Dec 2024 20:24:15 +0900
Subject: [PATCH 3/9] Update examples/train_pylate_contrastive.py

Co-authored-by: Antoine Chaffin <38869395+NohTow@users.noreply.github.com>
---
 examples/train_pylate_contrastive.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/train_pylate_contrastive.py b/examples/train_pylate_contrastive.py
index 860b14f2..dded392c 100644
--- a/examples/train_pylate_contrastive.py
+++ b/examples/train_pylate_contrastive.py
@@ -11,7 +11,6 @@
 import torch
 torch._inductor.config.fallback_random = True
 torch._inductor.config.triton.unique_kernel_names = True
-# Or completely disable torch compile
 # ModernBERT is compiled by default, so there is no need to call compile explicitly and it actually breaks the model
 # model = torch.compile(model)
 

From 9fba67febaa9cd496536fd344b773bb207fd8d47 Mon Sep 17 00:00:00 2001
From: Sigrid Jin (ง'̀-'́)ง oO
Date: Mon, 30 Dec 2024 20:24:20 +0900
Subject: [PATCH 4/9] Update examples/train_pylate_contrastive.py

Co-authored-by: Antoine Chaffin <38869395+NohTow@users.noreply.github.com>
---
 examples/train_pylate_contrastive.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/train_pylate_contrastive.py b/examples/train_pylate_contrastive.py
index dded392c..dfac7f68 100644
--- a/examples/train_pylate_contrastive.py
+++ b/examples/train_pylate_contrastive.py
@@ -10,7 +10,6 @@
 # Add at the start of your train_pylate_contrastive.py
 import torch
 torch._inductor.config.fallback_random = True
-torch._inductor.config.triton.unique_kernel_names = True
 # ModernBERT is compiled by default, so there is no need to call compile explicitly and it actually breaks the model
 # model = torch.compile(model)
 

From 75d19dfc882c9a4f148a0ede1e251c4e87dd4183 Mon Sep 17 00:00:00 2001
From: Sigrid Jin (ง'̀-'́)ง oO
Date: Mon, 30 Dec 2024 20:24:25 +0900
Subject: [PATCH 5/9] Update examples/train_pylate_contrastive.py

Co-authored-by: Antoine Chaffin <38869395+NohTow@users.noreply.github.com>
---
 examples/train_pylate_contrastive.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/train_pylate_contrastive.py b/examples/train_pylate_contrastive.py
index dfac7f68..77a88940 100644
--- a/examples/train_pylate_contrastive.py
+++ b/examples/train_pylate_contrastive.py
@@ -8,7 +8,6 @@
 from pylate import evaluation, losses, models, utils
 
 # Add at the start of your train_pylate_contrastive.py
-import torch
 torch._inductor.config.fallback_random = True
 # ModernBERT is compiled by default, so there is no need to call compile explicitly and it actually breaks the model
 # model = torch.compile(model)

From 425df2fa8b54cc84445564332272042bc5fad7a6 Mon Sep 17 00:00:00 2001
From: Sigrid Jin (ง'̀-'́)ง oO
Date: Mon, 30 Dec 2024 20:24:31 +0900
Subject: [PATCH 6/9] Update examples/train_pylate_contrastive.py

Co-authored-by: Antoine Chaffin <38869395+NohTow@users.noreply.github.com>
---
 examples/train_pylate_contrastive.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/train_pylate_contrastive.py b/examples/train_pylate_contrastive.py
index 77a88940..f67e6fab 100644
--- a/examples/train_pylate_contrastive.py
+++ b/examples/train_pylate_contrastive.py
@@ -8,7 +8,6 @@
 from pylate import evaluation, losses, models, utils
 
 # Add at the start of your train_pylate_contrastive.py
-torch._inductor.config.fallback_random = True
 # ModernBERT is compiled by default, so there is no need to call compile explicitly and it actually breaks the model
 # model = torch.compile(model)
 

From c41cd6b818bb781842b734774b7d13f8c722e1b2 Mon Sep 17 00:00:00 2001
From: Sigrid Jin (ง'̀-'́)ง oO
Date: Mon, 30 Dec 2024 20:24:36 +0900
Subject: [PATCH 7/9] Update examples/train_pylate_contrastive.py

Co-authored-by: Antoine Chaffin <38869395+NohTow@users.noreply.github.com>
---
 examples/train_pylate_contrastive.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/train_pylate_contrastive.py b/examples/train_pylate_contrastive.py
index f67e6fab..bd063d54 100644
--- a/examples/train_pylate_contrastive.py
+++ b/examples/train_pylate_contrastive.py
@@ -7,7 +7,6 @@
 
 from pylate import evaluation, losses, models, utils
 
-# Add at the start of your train_pylate_contrastive.py
 # ModernBERT is compiled by default, so there is no need to call compile explicitly and it actually breaks the model
 # model = torch.compile(model)
 

From bf812437eb5c7a124cbf538e12586f66270aea1c Mon Sep 17 00:00:00 2001
From: Antoine Chaffin
Date: Thu, 2 Jan 2025 10:37:30 +0000
Subject: [PATCH 8/9] Renaming and consistency with kd script

---
 examples/train_pylate_contrastive.py         | 110 ++++++++++--------
 .../{train_pylate.py => train_pylate_kd.py}  |   4 +-
 2 files changed, 63 insertions(+), 51 deletions(-)
 rename examples/{train_pylate.py => train_pylate_kd.py} (99%)

diff --git a/examples/train_pylate_contrastive.py b/examples/train_pylate_contrastive.py
index bd063d54..1097a77c 100644
--- a/examples/train_pylate_contrastive.py
+++ b/examples/train_pylate_contrastive.py
@@ -1,65 +1,75 @@
-import torch
+# Copyright 2024 onwards Answer.AI, LightOn, and contributors
+# License: Apache-2.0
 from datasets import load_dataset
+from pylate import evaluation, losses, models, utils
 from sentence_transformers import (
     SentenceTransformerTrainer,
     SentenceTransformerTrainingArguments,
 )
 
-from pylate import evaluation, losses, models, utils
 
-# ModernBERT is compiled by default, so there is no need to call compile explicitly and it actually breaks the model
-# model = torch.compile(model)
+def main():
+    # Define model parameters for contrastive training
+    model_name = "answerdotai/ModernBERT-base"  # Choose the pre-trained model you want to use as base
+    model_shortname = model_name.split("/")[-1]
 
-# Define model parameters for contrastive training
-model_name = "answerdotai/ModernBERT-large"  # Choose the pre-trained model you want to use as base
-batch_size = 64  # Larger batch size often improves results, but requires more memory
+    batch_size = 64  # Larger batch size often improves results, but requires more memory
+    lr = 3e-6
+    num_train_epochs = 5  # Adjust based on your requirements
 
-num_train_epochs = 5  # Adjust based on your requirements
-# Set the run name for logging and output directory
-run_name = "contrastive-sigrid-241226"
-output_dir = f"output/{run_name}"
+    # Set the run name for logging and output directory
+    run_name = f"{model_shortname}-colbert-contrastive-{lr}"
+    output_dir = f"output/{model_shortname}/{run_name}"
 
-# 1. Here we define our ColBERT model. If not a ColBERT model, will add a linear layer to the base encoder.
-model = models.ColBERT(model_name_or_path=model_name)
+    # 1. Here we define our ColBERT model. If not a ColBERT model, will add a linear layer to the base encoder.
+    model = models.ColBERT(model_name_or_path=model_name)
+    # ModernBERT is compiled by default, so there is no need to call compile explicitly and it actually breaks the model
+    # model = torch.compile(model)
 
-# Load dataset
-dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
-# Split the dataset (this dataset does not have a validation set, so we split the training set)
-splits = dataset.train_test_split(test_size=0.01)
-train_dataset = splits["train"]
-eval_dataset = splits["test"]
+    # Load dataset
+    dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
+    # Split the dataset (this dataset does not have a validation set, so we split the training set)
+    splits = dataset.train_test_split(test_size=0.01)
+    train_dataset = splits["train"]
+    eval_dataset = splits["test"]
 
-# Define the loss function
-train_loss = losses.Contrastive(model=model)
+    # Define the loss function
+    train_loss = losses.Contrastive(model=model)
 
-# Initialize the evaluator
-dev_evaluator = evaluation.ColBERTTripletEvaluator(
-    anchors=eval_dataset["query"],
-    positives=eval_dataset["positive"],
-    negatives=eval_dataset["negative"],
-)
+    # Initialize the evaluator
+    dev_evaluator = evaluation.ColBERTTripletEvaluator(
+        anchors=eval_dataset["query"],
+        positives=eval_dataset["positive"],
+        negatives=eval_dataset["negative"],
+    )
 
-# Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
-args = SentenceTransformerTrainingArguments(
-    output_dir=output_dir,
-    num_train_epochs=num_train_epochs,
-    per_device_train_batch_size=batch_size,
-    per_device_eval_batch_size=batch_size,
-    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
-    bf16=False,  # Set to True if you have a GPU that supports BF16
-    run_name=run_name,  # Will be used in W&B if `wandb` is installed
-    learning_rate=3e-6,
-)
+    # Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
+    args = SentenceTransformerTrainingArguments(
+        output_dir=output_dir,
+        num_train_epochs=num_train_epochs,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+        bf16=False,  # Set to True if you have a GPU that supports BF16
+        run_name=run_name,  # Will be used in W&B if `wandb` is installed
+        learning_rate=lr,
+    )
 
-# Initialize the trainer for the contrastive training
-trainer = SentenceTransformerTrainer(
-    model=model,
-    args=args,
-    train_dataset=train_dataset,
-    eval_dataset=eval_dataset,
-    loss=train_loss,
-    evaluator=dev_evaluator,
-    data_collator=utils.ColBERTCollator(model.tokenize),
-)
-# Start the training process
-trainer.train()
+    # Initialize the trainer for the contrastive training
+    trainer = SentenceTransformerTrainer(
+        model=model,
+        args=args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        loss=train_loss,
+        evaluator=dev_evaluator,
+        data_collator=utils.ColBERTCollator(model.tokenize),
+    )
+    # Start the training process
+    trainer.train()
+
+    model.save_pretrained(f"{output_dir}/final")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/train_pylate.py b/examples/train_pylate_kd.py
similarity index 99%
rename from examples/train_pylate.py
rename to examples/train_pylate_kd.py
index fa7d46dd..4e2c83d0 100644
--- a/examples/train_pylate.py
+++ b/examples/train_pylate_kd.py
@@ -8,6 +8,7 @@
     SentenceTransformerTrainingArguments,
 )
 
+
 def main():
     # Load the datasets required for knowledge distillation (train, queries, documents)
     train = load_dataset(
@@ -76,5 +77,6 @@
 
     model.save_pretrained(f"{output_dir}/final")
 
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()

From ea42756cee893299c9d61b2822e3a2e422d3463b Mon Sep 17 00:00:00 2001
From: Antoine Chaffin
Date: Thu, 2 Jan 2025 13:28:26 +0000
Subject: [PATCH 9/9] Use bf16

---
 examples/train_pylate_contrastive.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/train_pylate_contrastive.py b/examples/train_pylate_contrastive.py
index 1097a77c..856b70cd 100644
--- a/examples/train_pylate_contrastive.py
+++ b/examples/train_pylate_contrastive.py
@@ -9,10 +9,10 @@
 
 
 def main():
-    # Define model parameters for contrastive training
     model_name = "answerdotai/ModernBERT-base"  # Choose the pre-trained model you want to use as base
     model_shortname = model_name.split("/")[-1]
 
+    # Define model parameters for contrastive training
     batch_size = 64  # Larger batch size often improves results, but requires more memory
     lr = 3e-6
     num_train_epochs = 5  # Adjust based on your requirements
@@ -49,10 +49,11 @@
         num_train_epochs=num_train_epochs,
         per_device_train_batch_size=batch_size,
         per_device_eval_batch_size=batch_size,
-        fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
-        bf16=False,  # Set to True if you have a GPU that supports BF16
+        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
+        bf16=True,  # Set to True if you have a GPU that supports BF16
         run_name=run_name,  # Will be used in W&B if `wandb` is installed
         learning_rate=lr,
+        logging_steps=100,
     )
 
     # Initialize the trainer for the contrastive training
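
Note on using the result of this series: after [PATCH 8/9] and [PATCH 9/9], training ends with model.save_pretrained(f"{output_dir}/final"). The sketch below shows how such a checkpoint could be loaded for indexing and retrieval, following the indexes/retrieve API from the PyLate README (indexes.Voyager, retrieve.ColBERT). The checkpoint path, index names, document ids, and example texts are illustrative assumptions, not part of these patches.

# Retrieval sketch (assumed paths and ids; adjust to your run).
from pylate import indexes, models, retrieve

# Load the checkpoint written by the training script; with lr = 3e-6 the
# run name f"{model_shortname}-colbert-contrastive-{lr}" renders as
# "ModernBERT-base-colbert-contrastive-3e-06".
model = models.ColBERT(
    model_name_or_path="output/ModernBERT-base/ModernBERT-base-colbert-contrastive-3e-06/final"
)

# Build a Voyager index that stores one embedding per document token.
index = indexes.Voyager(index_folder="pylate-index", index_name="msmarco", override=True)

documents_ids = ["d1", "d2"]
documents = [
    "ColBERT compares query and document token embeddings with MaxSim late interaction.",
    "MS MARCO is a large-scale dataset for passage ranking and question answering.",
]
documents_embeddings = model.encode(documents, is_query=False)  # token-level document embeddings
index.add_documents(documents_ids=documents_ids, documents_embeddings=documents_embeddings)

# Encode queries with the query-side tokenization, then search the index.
retriever = retrieve.ColBERT(index=index)
queries_embeddings = model.encode(["what is late interaction?"], is_query=True)
print(retriever.retrieve(queries_embeddings=queries_embeddings, k=2))

Documents and queries are encoded separately because ColBERT keeps one embedding per token on each side and scores with MaxSim at query time, which is why is_query must be set explicitly in each encode call.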