Skip to content

Train or fine-tune a new decoder with a different tokenizer #81

@GouChuan

Description

@GouChuan
import torch
import torch.nn.functional as F
# a helper function for padding a batch of tokenized texts
from fairseq2.nn.padding import pad_seqs
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline, EmbeddingToTextModelPipeline

# define the loss computation function

def get_decoder_loss(
    decoder, 
    batch_tokens, 
    batch_embs,
):
    """
    Compute the cross-entropy loss for each sentence in the batch (non-normalized),
    and return per-sentence losses alongside sentence lengths (for optional
    normalization).

    Args:
        decoder: a SONAR text decoder module exposing ``decoder_frontend``,
            ``decoder`` and ``final_proj`` (e.g. ``dec.model.decoder``).
        batch_tokens: list of 1-D integer token-id tensors of varying lengths;
            each sequence must start with the EOS token (id 3), followed by
            the language tag.
        batch_embs: (batch, embedding_dim) tensor of sentence embeddings
            conditioning the decoder — presumably SONAR's 1024-dim embeddings;
            TODO confirm against the encoder used.

    Returns:
        Tuple ``(per_sent_loss, per_sent_toks)``: the summed per-token loss for
        each sentence, and the count of non-ignored tokens per sentence.

    Raises:
        ValueError: if the first sequence does not start with EOS (id 3).
    """
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # so input validation must not rely on them.
    if int(batch_tokens[0][0]) != 3:
        raise ValueError("EOS TOKEN MUST BE PREPENDED WHEN TRAINING A DECODER")

    # reduction="none" keeps a loss per target token so we can aggregate per sentence.
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")

    # Pad the variable-length sequences into one (batch, max_len) tensor + mask.
    padded, mask = pad_seqs(batch_tokens)

    # Move all inputs onto the device the decoder's parameters live on.
    device = next(decoder.parameters()).device
    padded = padded.to(device)
    if mask is not None:
        mask = mask.to(device)
    batch_embs = batch_embs.to(device)

    # Feed the batch to the model in three steps
    # (embeddings + decoder body + output projection).
    seqs, padding_mask = decoder.decoder_frontend(
        padded, padding_mask=mask,
    )
    # Pass the mask returned by the frontend (the original passed the raw `mask`
    # and ignored `padding_mask`; the frontend's output mask is the one that
    # matches `seqs`).
    decoder_output, _ = decoder.decoder(
        seqs,
        padding_mask,
        encoder_output=batch_embs.unsqueeze(1),
    )
    logits = decoder.final_proj(decoder_output)

    # The "targets" are all tokens except the first one (beginning-of-sentence).
    labels = padded[:, 1:].clone()
    # Make the loss ignore the padding tokens (id 0).
    labels[labels == 0] = -100
    # Make the loss ignore the first label (always the language tag; it does
    # not have to be predicted).
    labels[:, 0] = -100

    loss = loss_fn(logits[:, :-1].reshape(-1, logits.size(-1)), labels.view(-1))
    per_token_loss = loss.view(logits[:, :-1].shape[:2])
    # Per-sentence loss is the sum of its per-token losses.
    per_sent_loss = per_token_loss.sum(-1)
    # Also count non-ignored tokens, so the total loss can be normalized by
    # the total number of tokens.
    per_sent_toks = (labels != -100).sum(1)

    return per_sent_loss, per_sent_toks

# Demo: run the loss function on a small sample batch.

# Instantiate the encoder and decoder pipelines on GPU.
enc = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=torch.device("cuda"),
)
dec = EmbeddingToTextModelPipeline(
    decoder="text_sonar_basic_decoder",
    tokenizer="text_sonar_basic_encoder",
    device=torch.device("cuda"),
)

target_lang = "eng_Latn"
batch_text = [
    "hello world",
    "hello",
    "hello world. my name is jeff",
]

# Tokenize every sentence: yields a list of integer tensors of different lengths.
batch_text_tokenized = [
    dec.tokenizer.create_encoder(mode='target', lang=target_lang)(text)
    for text in batch_text
]

# Encode the sentences into a 3*1024 embedding matrix.
batch_embs = enc.predict(batch_text, source_lang="eng_Latn")

# Compute the losses!
with torch.inference_mode():
    losses, n_toks = get_decoder_loss(dec.model.decoder, batch_text_tokenized, batch_embs)
print(losses)
# tensor([0.2440, 2.6020, 3.5046], device='cuda:0')
print(n_toks)
# tensor([3, 2, 9], device='cuda:0')

# The average per-token loss (what is normally optimized during training, and
# directly related to text perplexity) is the sum of all sentence losses
# divided by the total number of tokens:
avg_loss = losses.sum() / n_toks.sum()
print(avg_loss)
# tensor(0.4536, device='cuda:0')

Dear Authors,

I hope this message finds you well. I would like to obtain a decoder that uses a tokenizer different from the original one. Could you please advise on how I should train or fine-tune the model in this case?

Would it be appropriate to directly replace the tokenizer in the provided demo with my own tokenizer, or would additional modifications be necessary?

I would greatly appreciate any guidance you could offer.

Thank you very much for your time and assistance.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions