diff --git a/pretrain/pretrain_helpers.py b/pretrain/pretrain_helpers.py index 7329bec..a60f471 100644 --- a/pretrain/pretrain_helpers.py +++ b/pretrain/pretrain_helpers.py @@ -118,7 +118,7 @@ def scatter_update(sequence, updates, positions): def _get_candidates_mask(inputs: pretrain_data.Inputs, vocab, disallow_from_mask=None): """Returns a mask tensor of positions in the input that can be masked out.""" - ignore_ids = [vocab["[SEP]"], vocab["[CLS]"], vocab["[MASK]"]] + ignore_ids = [vocab["[SEP]"], vocab["[CLS]"], vocab["[MASK]"], vocab["[PAD]"]] candidates_mask = tf.ones_like(inputs.input_ids, tf.bool) for ignore_id in ignore_ids: candidates_mask &= tf.not_equal(inputs.input_ids, ignore_id)