Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,14 @@ def load_from_files_list(
return dataset

# For all datasets
def batch_preparation_img2seq(data):
class BatchCollator:
    """Collate callable that pads image-to-sequence batches with a configurable token.

    Wraps batch_preparation_img2seq so the pad token can be chosen per dataset.
    A class (rather than a lambda/closure) is presumably used so the collate_fn
    stays picklable for DataLoader workers — NOTE(review): confirm num_workers > 0
    is the motivation.
    """

    def __init__(self, pad_token: int = 0):
        # Token id written into padded positions of decoder_input and y.
        self.pad_token = pad_token

    def __call__(self, data):
        """Collate *data* (list of (image, decoder_input, gt) samples) into batch tensors."""
        return batch_preparation_img2seq(data, pad_token=self.pad_token)

def batch_preparation_img2seq(data, pad_token=0):
images = [sample[0] for sample in data]
dec_in = [sample[1] for sample in data]
gt = [sample[2] for sample in data]
Expand All @@ -101,16 +108,18 @@ def batch_preparation_img2seq(data):

max_length_seq = max([len(w) for w in gt])

decoder_input = torch.zeros(size=[len(dec_in),max_length_seq-1]) # <eos> will be removed
y = torch.zeros(size=[len(gt),max_length_seq-1]) # <bos> will be removed
decoder_input = torch.full(size=[len(dec_in),max_length_seq-1], fill_value=pad_token, dtype=torch.long) # <eos> will be removed
y = torch.full(size=[len(gt),max_length_seq-1], fill_value=pad_token, dtype=torch.long) # <bos> will be removed

for i, seq in enumerate(dec_in):
decoder_input[i] = torch.from_numpy(np.asarray([char for char in seq[:-1]])) # all tokens but <eos>
seq_tensor = torch.as_tensor(seq[:-1])
decoder_input[i, :len(seq_tensor)] = seq_tensor # all tokens but <eos>

for i, seq in enumerate(gt):
y[i] = torch.from_numpy(np.asarray([char for char in seq[1:]])) # all tokens but <bos>
seq_tensor = torch.as_tensor(seq[1:])
y[i, :len(seq_tensor)] = seq_tensor # all tokens but <bos>

return X_train, decoder_input.long(), y.long()
return X_train, decoder_input, y

class OMRIMG2SEQDataset(Dataset):
def __init__(
Expand Down Expand Up @@ -466,13 +475,13 @@ def get_max_length(self) -> int:
return max(Tl, vl, tl)

def train_dataloader(self):
return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=batch_preparation_img2seq)
return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=BatchCollator(self.train_set.padding_token))

def val_dataloader(self):
return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq)
return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.val_set.padding_token))

def test_dataloader(self):
return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq)
return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.test_set.padding_token))

# Synthetic system-level GrandStaff training
# NOTE: Pre-train the SMT on system-level data using this dataset
Expand Down Expand Up @@ -504,13 +513,13 @@ def get_max_length(self) -> int:
return 4360

def train_dataloader(self):
return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=batch_preparation_img2seq)
return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=BatchCollator(self.train_set.padding_token))

def val_dataloader(self):
return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq)
return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.val_set.padding_token))

def test_dataloader(self):
return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq)
return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.test_set.padding_token))

# Synthetic system-to-full-page GrandStaff curriculum training
# NOTE: Fine-tune the SMT on page-level data with curriculum learning
Expand Down Expand Up @@ -555,11 +564,11 @@ def get_max_length(self) -> int:
return max(Tl, vl, tl, 4353)

def train_dataloader(self):
return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=batch_preparation_img2seq)
# return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=0, shuffle=True, collate_fn=batch_preparation_img2seq)
return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=BatchCollator(self.train_set.padding_token))
# return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=0, shuffle=True, collate_fn=BatchCollator(self.train_set.padding_token))

def val_dataloader(self):
return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq)
return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.val_set.padding_token))

def test_dataloader(self):
return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq)
return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.test_set.padding_token))
34 changes: 24 additions & 10 deletions smt_model/modeling_smt.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,21 +399,35 @@ def forward(self, encoder_input, decoder_input, labels=None):

@torch.no_grad
def predict(self, input, convert_to_str=False, return_weights=False):
predicted_sequence = torch.from_numpy(np.asarray([self.w2i['<bos>']])).to(input.device).unsqueeze(0)
b = input.size(0)
predicted_sequence = torch.full((b, 1), self.w2i['<bos>'], dtype=torch.long, device=input.device)
encoder_output = self.forward_encoder(input)
text_sequence = []
for i in range(self.maxlen - predicted_sequence.shape[-1]):

has_eos = torch.zeros(b, dtype=torch.bool, device=input.device)
eos_id = self.w2i['<eos>']

for i in range(self.maxlen - predicted_sequence.size(1)):
output = self.forward_decoder(encoder_output=encoder_output, last_predictions=predicted_sequence,
return_weights=return_weights)
predicted_token = torch.argmax(output.logits[:, -1, :], dim=-1).item()
predicted_sequence = torch.cat([predicted_sequence, torch.argmax(output.logits[:, -1, :], dim=-1, keepdim=True)], dim=1)
if convert_to_str:
predicted_token = f"{predicted_token}"
if self.i2w[predicted_token] == '<eos>':
predicted_tokens = torch.argmax(output.logits[:, -1, :], dim=-1, keepdim=True)
predicted_sequence = torch.cat([predicted_sequence, predicted_tokens], dim=1)

has_eos |= (predicted_tokens.squeeze(1) == eos_id)
if has_eos.all():
break
text_sequence.append(self.i2w[predicted_token])

return text_sequence, output
text_sequences = []
for b_idx in range(b):
seq = []
for token_id in predicted_sequence[b_idx, 1:]:
token_val = str(token_id.item()) if convert_to_str else token_id.item()
token_str = self.i2w.get(token_val, "")
if token_str == '<eos>':
break
seq.append(token_str)
text_sequences.append(seq)

return text_sequences, output


def _generate_token_mask(self, token_len, total_size, device):
Expand Down
56 changes: 35 additions & 21 deletions smt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,37 +53,51 @@ def training_step(self, batch):

stage = self.stage_calculator(self.global_step)

self.log('loss', loss, on_epoch=True, batch_size=1, prog_bar=True)
self.log("stage", stage, on_epoch=True, prog_bar=True)
self.log('loss', loss, on_epoch=True, batch_size=x.size(0), prog_bar=True)
self.log("stage", stage, on_epoch=True, batch_size=x.size(0), prog_bar=True)

return loss


def validation_step(self, val_batch):
x, _, y = val_batch
predicted_sequence, _ = self.model.predict(input=x)

dec = "".join(predicted_sequence)
dec = dec.replace("<t>", "\t")
dec = dec.replace("<b>", "\n")
dec = dec.replace("<s>", " ")

gt = "".join([self.model.i2w[token.item()] for token in y.squeeze(0)[:-1]]) # Remove <eos>
gt = gt.replace("<t>", "\t")
gt = gt.replace("<b>", "\n")
gt = gt.replace("<s>", " ")

self.preds.append(dec)
self.grtrs.append(gt)

predicted_sequences, _ = self.model.predict(input=x)

for i, predicted_sequence in enumerate(predicted_sequences):
y_i = y[i]

dec = "".join(predicted_sequence)
dec = dec.replace("<t>", "\t")
dec = dec.replace("<b>", "\n")
dec = dec.replace("<s>", " ")

gt_tokens = []
for token in y_i: # y_i is 1D
token_item = token.item()
token_str = self.model.i2w.get(token_item, "")
if token_str == '<eos>':
break
if token_str not in ['<pad>', '<bos>', '']:
gt_tokens.append(token_str)

gt = "".join(gt_tokens)
gt = gt.replace("<t>", "\t")
gt = gt.replace("<b>", "\n")
gt = gt.replace("<s>", " ")

self.preds.append(dec)
self.grtrs.append(gt)

def on_validation_epoch_end(self, metric_name="val") -> None:
cer, ser, ler = compute_poliphony_metrics(self.preds, self.grtrs)

random_index = random.randint(0, len(self.preds)-1)
predtoshow = self.preds[random_index]
gttoshow = self.grtrs[random_index]
print(f"[Prediction] - {predtoshow}")
print(f"[GT] - {gttoshow}")
if len(self.preds) > 0:
random_index = random.randint(0, len(self.preds)-1)
predtoshow = self.preds[random_index]
gttoshow = self.grtrs[random_index]
print(f"[Prediction] - {predtoshow}")
print(f"[GT] - {gttoshow}")

self.log(f'{metric_name}_CER', cer, on_epoch=True, prog_bar=True)
self.log(f'{metric_name}_SER', ser, on_epoch=True, prog_bar=True)
Expand Down
88 changes: 88 additions & 0 deletions test/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import numpy as np
from data import batch_preparation_img2seq, BatchCollator

def test_batch_padding():
    """Verify that BatchCollator / batch_preparation_img2seq pad ragged
    sequences with the configured pad token.

    Each sample is (image, decoder_input, y); both token sequences include
    <bos> and <eos>, so the collated tensors have max_len - 1 columns.
    """
    pad_value = 99
    collate = BatchCollator(pad_token=pad_value)

    # Two ragged sequences: <bos>, a, b, c, <eos> and <bos>, a, <eos>.
    long_seq = torch.tensor([1, 2, 3, 4, 5])
    short_seq = torch.tensor([1, 2, 5])

    samples = [
        (torch.ones(1, 100, 100), long_seq, long_seq.clone()),
        (torch.ones(1, 50, 50), short_seq, short_seq.clone()),
    ]

    X_train, decoder_input, y = collate(samples)

    print("Decoder input:")
    print(decoder_input)
    print("Labels (y):")
    print(y)

    # max_len is 5; one token (<eos> / <bos>) is stripped -> 4 columns.
    assert decoder_input.shape == (2, 4)
    assert y.shape == (2, 4)

    # The short sequence must be filled with pad_value past its length.
    assert decoder_input[1, 2].item() == 99
    assert y[1, 2].item() == 99

    # The plain function should honour an explicit pad_token as well.
    _, decoder_input2, y2 = batch_preparation_img2seq(samples, pad_token=77)
    assert decoder_input2[1, 2].item() == 77
    assert y2[1, 2].item() == 77

def test_batch_validation_loop():
    """Simulate the per-sample ground-truth decoding loop used in
    validation_step for a batch of size 2.

    Decoding must stop at <eos> and skip <pad>/<bos>/unknown tokens, so
    padding never leaks into the ground-truth strings.
    """
    b = 2

    # 2 is <bos>, 3 is 'a', 4 is 'b', 5 is 'c', 1 is <eos>, 0 is <pad>
    y = torch.tensor([
        [2, 3, 4, 1, 0],
        [2, 5, 1, 0, 0]
    ])

    i2w = {0: '<pad>', 1: '<eos>', 2: '<bos>', 3: 'a', 4: 'b', 5: 'c'}

    # NOTE: the original also built a `preds` list that was never used;
    # it has been removed.
    grtrs = []

    for i in range(b):
        y_i = y[i]

        gt_tokens = []
        for token in y_i:
            token_item = token.item()
            token_str = i2w.get(token_item, "")
            if token_str == '<eos>':
                break  # everything after <eos> is padding
            if token_str not in ['<pad>', '<bos>', '']:
                gt_tokens.append(token_str)

        gt = "".join(gt_tokens)
        grtrs.append(gt)

    assert grtrs[0] == "ab", f"Got {grtrs[0]}"
    assert grtrs[1] == "c", f"Got {grtrs[1]}"
    print("test_batch_validation_loop passed!")

if __name__ == "__main__":
    # Allow running this module directly (without pytest) for a quick check.
    test_batch_padding()
    test_batch_validation_loop()