From 795fb2cbc305b275a9cef16aa19a701a4f0eeb72 Mon Sep 17 00:00:00 2001 From: iejMac Date: Fri, 2 Sep 2022 06:19:00 +0000 Subject: [PATCH] training: optional transformation to embeddings prior to pooling --- src/training/train.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/training/train.py b/src/training/train.py index 53fb363..5e314bc 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -9,6 +9,8 @@ from training.loss import ClipLoss from .distributed import is_master +from open_clip import tokenize + def train_one_epoch(model_video, model_text, logit_scale, data, epoch, optimizer, scheduler, args, tb_writer=None): device = torch.device(args.device) @@ -24,6 +26,10 @@ def train_one_epoch(model_video, model_text, logit_scale, data, epoch, optimizer dataloader = data["train"].dataloader num_batches_per_epoch = dataloader.num_batches + with torch.no_grad(): + shutterstock_embs = model_text(tokenize(["shutterstock"]).to(device)) + shutterstock_embs = torch.cat([shutterstock_embs]*args.sequence_length) + running_loss = 0.0 for i, batch in enumerate(dataloader): step = num_batches_per_epoch * epoch + i @@ -31,6 +37,7 @@ def train_one_epoch(model_video, model_text, logit_scale, data, epoch, optimizer embeddings, toks = batch embeddings = embeddings.to(device, non_blocking=True) + embeddings -= shutterstock_embs toks = toks.to(device, non_blocking=True) optimizer.zero_grad() @@ -76,6 +83,10 @@ def evaluate(model_video, model_text, logit_scale, data, epoch, args, tb_writer= dataloader = data["val"].dataloader model_video.eval() + with torch.no_grad(): + shutterstock_embs = model_text(tokenize(["shutterstock"]).to(device)) + shutterstock_embs = torch.cat([shutterstock_embs]*args.sequence_length) + metrics["val_loss"] = 0.0 loss_func = ClipLoss( local_loss=False, @@ -93,6 +104,7 @@ def evaluate(model_video, model_text, logit_scale, data, epoch, args, tb_writer= embeddings, toks = batch embeddings = embeddings.to(device, non_blocking=True) + embeddings -= shutterstock_embs toks = toks.to(device, non_blocking=True) video_embeddings = model_video(embeddings, None)