diff --git a/data.py b/data.py index 00ecfd8..0aff419 100644 --- a/data.py +++ b/data.py @@ -85,7 +85,14 @@ def load_from_files_list( return dataset # For all datasets -def batch_preparation_img2seq(data): +class BatchCollator: + def __init__(self, pad_token: int = 0): + self.pad_token = pad_token + + def __call__(self, data): + return batch_preparation_img2seq(data, pad_token=self.pad_token) + +def batch_preparation_img2seq(data, pad_token=0): images = [sample[0] for sample in data] dec_in = [sample[1] for sample in data] gt = [sample[2] for sample in data] @@ -101,16 +108,18 @@ def batch_preparation_img2seq(data): max_length_seq = max([len(w) for w in gt]) - decoder_input = torch.zeros(size=[len(dec_in),max_length_seq-1]) # will be removed - y = torch.zeros(size=[len(gt),max_length_seq-1]) # will be removed + decoder_input = torch.full(size=[len(dec_in),max_length_seq-1], fill_value=pad_token, dtype=torch.long) # will be removed + y = torch.full(size=[len(gt),max_length_seq-1], fill_value=pad_token, dtype=torch.long) # will be removed for i, seq in enumerate(dec_in): - decoder_input[i] = torch.from_numpy(np.asarray([char for char in seq[:-1]])) # all tokens but + seq_tensor = torch.as_tensor(seq[:-1]) + decoder_input[i, :len(seq_tensor)] = seq_tensor # all tokens but for i, seq in enumerate(gt): - y[i] = torch.from_numpy(np.asarray([char for char in seq[1:]])) # all tokens but + seq_tensor = torch.as_tensor(seq[1:]) + y[i, :len(seq_tensor)] = seq_tensor # all tokens but - return X_train, decoder_input.long(), y.long() + return X_train, decoder_input, y class OMRIMG2SEQDataset(Dataset): def __init__( @@ -466,13 +475,13 @@ def get_max_length(self) -> int: return max(Tl, vl, tl) def train_dataloader(self): - return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=batch_preparation_img2seq) + return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, 
num_workers=self.num_workers, shuffle=True, collate_fn=BatchCollator(self.train_set.padding_token)) def val_dataloader(self): - return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq) + return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.train_set.padding_token)) def test_dataloader(self): - return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq) + return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.train_set.padding_token)) # Synthetic system-level GrandStaff training # NOTE: Pre-train the SMT on system-level data using this dataset @@ -504,13 +513,13 @@ def get_max_length(self) -> int: return 4360 def train_dataloader(self): - return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=batch_preparation_img2seq) + return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=BatchCollator(self.train_set.padding_token)) def val_dataloader(self): - return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq) + return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.train_set.padding_token)) def test_dataloader(self): - return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq) + return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, 
collate_fn=BatchCollator(self.train_set.padding_token)) # Synthetic system-to-full-page GrandStaff curriculum training # NOTE: Fine-tune the SMT on page-level data with curriculum learning @@ -555,11 +564,11 @@ def get_max_length(self) -> int: return max(Tl, vl, tl, 4353) def train_dataloader(self): - return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=batch_preparation_img2seq) - # return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=0, shuffle=True, collate_fn=batch_preparation_img2seq) + return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, collate_fn=BatchCollator(self.train_set.padding_token)) + # return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, num_workers=0, shuffle=True, collate_fn=BatchCollator(self.train_set.padding_token)) def val_dataloader(self): - return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq) + return torch.utils.data.DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.train_set.padding_token)) def test_dataloader(self): - return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=batch_preparation_img2seq) + return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=BatchCollator(self.train_set.padding_token)) diff --git a/smt_model/modeling_smt.py b/smt_model/modeling_smt.py index 58c0845..18be876 100644 --- a/smt_model/modeling_smt.py +++ b/smt_model/modeling_smt.py @@ -399,21 +399,35 @@ def forward(self, encoder_input, decoder_input, labels=None): @torch.no_grad def predict(self, input, convert_to_str=False, return_weights=False): - predicted_sequence = 
# --- Reconstructed post-patch `predict` from smt_model/modeling_smt.py ---
# NOTE(review): the special-token string literals in this span were stripped
# by sanitization (the mangled source reads `self.w2i['']`); '<bos>'/'<eos>'
# below follow the SMT vocabulary convention — confirm against the
# checkpoint's w2i before merging.
@torch.no_grad()  # parentheses added: works on all torch versions
def predict(self, input, convert_to_str=False, return_weights=False):
    """Greedy batched autoregressive decoding.

    Args:
        input: batched encoder input tensor, batch dimension first.
        convert_to_str: when True, look tokens up in ``self.i2w`` by
            stringified id instead of the raw int id.
        return_weights: forwarded to the decoder.

    Returns:
        ``(text_sequences, output)``: one decoded token-string list per
        batch element (without <bos>/<eos>), plus the final decoder output.
    """
    batch = input.size(0)
    predicted_sequence = torch.full((batch, 1), self.w2i['<bos>'],
                                    dtype=torch.long, device=input.device)
    encoder_output = self.forward_encoder(input)

    eos_id = self.w2i['<eos>']
    has_eos = torch.zeros(batch, dtype=torch.bool, device=input.device)

    for _ in range(self.maxlen - predicted_sequence.size(1)):
        output = self.forward_decoder(encoder_output=encoder_output,
                                      last_predictions=predicted_sequence,
                                      return_weights=return_weights)
        predicted_tokens = torch.argmax(output.logits[:, -1, :], dim=-1, keepdim=True)
        predicted_sequence = torch.cat([predicted_sequence, predicted_tokens], dim=1)

        # Stop only once every sequence in the batch has emitted <eos>.
        has_eos |= (predicted_tokens.squeeze(1) == eos_id)
        if has_eos.all():
            break

    text_sequences = []
    for b_idx in range(batch):
        seq = []
        for token_id in predicted_sequence[b_idx, 1:]:
            token_val = str(token_id.item()) if convert_to_str else token_id.item()
            # The sanitized source read `.get(token_val, "")` with an empty
            # break sentinel; unknown ids and <eos> both terminate here —
            # TODO confirm the intended fallback token.
            token_str = self.i2w.get(token_val, '<eos>')
            if token_str == '<eos>':
                break
            seq.append(token_str)
        text_sequences.append(seq)

    return text_sequences, output
# --- Reconstructed post-patch `validation_step` from smt_trainer.py ---
# NOTE(review): every special-token literal in this span was stripped by
# sanitization (`dec.replace("", "\t")` etc.); the '<t>'/'<b>'/'<s>' and
# '<bos>'/'<eos>'/'<pad>' names below follow the SMT convention — confirm.
# This span also held the tail of `training_step` (its `self.log` calls, now
# with `batch_size=x.size(0)` instead of 1) and the head of
# `on_validation_epoch_end` (now guarded by `if len(self.preds) > 0`); both
# hunks are incomplete here and are deliberately not reconstructed.
def validation_step(self, val_batch):
    """Decode a validation batch and accumulate prediction/GT string pairs."""
    x, _, y = val_batch

    predicted_sequences, _ = self.model.predict(input=x)

    for i, predicted_sequence in enumerate(predicted_sequences):
        y_i = y[i]  # 1D tensor of ground-truth token ids for sample i

        dec = "".join(predicted_sequence)
        dec = dec.replace("<t>", "\t")
        dec = dec.replace("<b>", "\n")
        dec = dec.replace("<s>", " ")

        gt_tokens = []
        for token in y_i:
            token_item = token.item()
            # Hedge: the sanitized source read `.get(token_item, "")` with
            # empty sentinel strings; unknown ids are treated like padding
            # here — TODO confirm intended fallback.
            token_str = self.model.i2w.get(token_item, '<pad>')
            if token_str == '<eos>':
                break  # end of sequence; anything after is padding
            if token_str not in ('<bos>', '<pad>'):
                gt_tokens.append(token_str)

        gt = "".join(gt_tokens)
        gt = gt.replace("<t>", "\t")
        gt = gt.replace("<b>", "\n")
        gt = gt.replace("<s>", " ")

        self.preds.append(dec)
        self.grtrs.append(gt)
file mode 100644 index 0000000..9b842aa --- /dev/null +++ b/test/test_data.py @@ -0,0 +1,88 @@ +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch +import numpy as np +from data import batch_preparation_img2seq, BatchCollator + +def test_batch_padding(): + # Simulate batch data + # format: [(image, decoder_input, y), ...] + # image: tensor of shape (1, H, W) + # decoder_input: tensor of seq len (includes and ) + # y: same as decoder_input + + pad_token = 99 + collator = BatchCollator(pad_token=pad_token) + + dec_in_1 = torch.tensor([1, 2, 3, 4, 5]) # e.g. , a, b, c, + dec_in_2 = torch.tensor([1, 2, 5]) # e.g. , a, + + gt_1 = dec_in_1.clone() + gt_2 = dec_in_2.clone() + + img_1 = torch.ones(1, 100, 100) + img_2 = torch.ones(1, 50, 50) + + data = [ + (img_1, dec_in_1, gt_1), + (img_2, dec_in_2, gt_2) + ] + + X_train, decoder_input, y = collator(data) + + print("Decoder input:") + print(decoder_input) + print("Labels (y):") + print(y) + + assert decoder_input.shape == (2, 4) + assert y.shape == (2, 4) + + # sequence 2 should be padded with 99 (pad_token) + assert decoder_input[1, 2].item() == 99 + assert y[1, 2].item() == 99 + + # double check the direct function handles it too + _, decoder_input2, y2 = batch_preparation_img2seq(data, pad_token=77) + assert decoder_input2[1, 2].item() == 77 + assert y2[1, 2].item() == 77 + +def test_batch_validation_loop(): + # Simulated validation_step loop for a batch size of 2 + b = 2 + + # 2 is , 3 is 'a', 4 is 'b', 5 is 'c', 1 is , 0 is + y = torch.tensor([ + [2, 3, 4, 1, 0], + [2, 5, 1, 0, 0] + ]) + + i2w = {0: '', 1: '', 2: '', 3: 'a', 4: 'b', 5: 'c'} + + preds = [] + grtrs = [] + + for i in range(b): + y_i = y[i] + + gt_tokens = [] + for token in y_i: + token_item = token.item() + token_str = i2w.get(token_item, "") + if token_str == '': + break + if token_str not in ['', '', '']: + gt_tokens.append(token_str) + + gt = "".join(gt_tokens) + grtrs.append(gt) + + 
assert grtrs[0] == "ab", f"Got {grtrs[0]}" + assert grtrs[1] == "c", f"Got {grtrs[1]}" + print("test_batch_validation_loop passed!") + +if __name__ == "__main__": + test_batch_padding() + test_batch_validation_loop()