From c19ca2f380053a3d2d9ca0b4fe31140f2861b95b Mon Sep 17 00:00:00 2001 From: Boris Vassilev Date: Thu, 11 Sep 2025 15:29:19 +0300 Subject: [PATCH] Mistake in trellis evaluation in accordance with the CTC paper Trellis computation has to also account for the probability of repeating the last token not only by emitting a blank, but also by actually repeating it. --- examples/tutorials/forced_alignment_tutorial.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/tutorials/forced_alignment_tutorial.py b/examples/tutorials/forced_alignment_tutorial.py index 7fa7c86dc3..83783b53ed 100644 --- a/examples/tutorials/forced_alignment_tutorial.py +++ b/examples/tutorials/forced_alignment_tutorial.py @@ -146,10 +146,11 @@ def plot(): # # To generate, the probability of time step :math:`t+1`, we look at the # trellis from time step :math:`t` and emission at time step :math:`t+1`. -# There are two path to reach to time step :math:`t+1` with label -# :math:`c_{j+1}`. The first one is the case where the label was +# There are three paths to reach time step :math:`t+1` with label +# :math:`c_{j+1}`. The first two are the cases where the label was # :math:`c_{j+1}` at :math:`t` and there was no label change from -# :math:`t` to :math:`t+1`. The other case is where the label was +# :math:`t` to :math:`t+1`. For this we use the probability of 'blank' +# emission and repeating the same letter. The last case is where the label was # :math:`c_j` at :math:`t` and it transitioned to the next label # :math:`c_{j+1}` at :math:`t+1`. 
# @@ -160,7 +161,7 @@ def plot(): # Since we are looking for the most likely transitions, we take the more # likely path for the value of :math:`k_{(t+1, j+1)}`, that is # -# :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )` +# :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, blank), k_{(t, j+1)} p(t+1, c_{j+1}))` # # where :math:`k` represents is trellis matrix, and :math:`p(t, c_j)` # represents the probability of label :math:`c_j` at time step :math:`t`. @@ -194,6 +195,10 @@ def get_trellis(emission, tokens, blank_id=0): # Score for changing to the next token trellis[t, :-1] + emission[t, tokens[1:]], ) + trellis[t+1,1:] = torch.maximum( + trellis[t+1,1:], + trellis[t, 1:] + emission[t, tokens[1:]], + ) return trellis