From 32abc46944f6f8717fd09f01afdc538a9e1d164f Mon Sep 17 00:00:00 2001 From: fedor-grigoryev <66886867+fedor-grigoryev@users.noreply.github.com> Date: Wed, 22 Dec 2021 23:18:29 +0300 Subject: [PATCH] fix --- week8_scst/basic_model_torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/week8_scst/basic_model_torch.py b/week8_scst/basic_model_torch.py index fc68ae9..00c9c05 100644 --- a/week8_scst/basic_model_torch.py +++ b/week8_scst/basic_model_torch.py @@ -117,7 +117,7 @@ def translate(self, inp, greedy=False, max_len = None, eps = 1e-30, **flags): def infer_mask(seq, eos_ix, batch_first=True, include_eos=True, type=torch.FloatTensor): """ - compute length given output indices and eos code + compute mask given output indices and eos code :param seq: tf matrix [time,batch] if batch_first else [batch,time] :param eos_ix: integer index of end-of-sentence token :param include_eos: if True, the time-step where eos first occurs is has mask = 1 @@ -136,8 +136,8 @@ def infer_mask(seq, eos_ix, batch_first=True, include_eos=True, type=torch.Float def infer_length(seq, eos_ix, batch_first=True, include_eos=True, type=torch.LongTensor): """ - compute mask given output indices and eos code - :param seq: tf matrix [time,batch] if time_major else [batch,time] + compute length given output indices and eos code + :param seq: tf matrix [time,batch] if batch_first else [batch,time] :param eos_ix: integer index of end-of-sentence token :param include_eos: if True, the time-step where eos first occurs is has mask = 1 :returns: mask, float32 matrix with '0's and '1's of same shape as seq