diff --git a/src/transformers/models/prophetnet/tokenization_prophetnet.py b/src/transformers/models/prophetnet/tokenization_prophetnet.py
index 24401835c7fc..efac7145d675 100644
--- a/src/transformers/models/prophetnet/tokenization_prophetnet.py
+++ b/src/transformers/models/prophetnet/tokenization_prophetnet.py
@@ -354,6 +354,11 @@ def __init__(
             )
         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
 
+        # Cache the unk id once so _convert_token_to_id avoids a second dict
+        # lookup per call. Use str(unk_token): self.unk_token is not reliably
+        # set before super().__init__(), and the raw argument may be an AddedToken.
+        self._unk_token_id = self.vocab.get(str(unk_token))
+
         super().__init__(
             do_lower_case=do_lower_case,
             do_basic_tokenize=do_basic_tokenize,
@@ -391,7 +396,11 @@ def _tokenize(self, text):
 
     def _convert_token_to_id(self, token: str):
         """Converts a token (str) in an id using the vocab."""
-        return self.vocab.get(token, self.vocab.get(self.unk_token))
+        # Single lookup for the common hit; fall back to the unk id cached in __init__.
+        id_ = self.vocab.get(token)
+        if id_ is not None:
+            return id_
+        return self._unk_token_id
 
     def _convert_id_to_token(self, index: int):
         """Converts an index (integer) in a token (str) using the vocab."""