diff --git a/src/transformers/models/prophetnet/tokenization_prophetnet.py b/src/transformers/models/prophetnet/tokenization_prophetnet.py
index 24401835c7fc..09e81be8113a 100644
--- a/src/transformers/models/prophetnet/tokenization_prophetnet.py
+++ b/src/transformers/models/prophetnet/tokenization_prophetnet.py
@@ -34,8 +34,8 @@ def whitespace_tokenize(text):
     text = text.strip()
     if not text:
         return []
-    tokens = text.split()
-    return tokens
+    # Avoid the intermediate variable, just return directly
+    return text.split()


 # Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
@@ -78,6 +78,10 @@ def __init__(
         self.strip_accents = strip_accents
         self.do_split_on_punc = do_split_on_punc

+        # Cache frequently used unicodedata functions as instance attributes for performance
+        self._unicodedata_normalize = unicodedata.normalize
+        self._unicodedata_category = unicodedata.category
+
     def tokenize(self, text, never_split=None):
         """
         Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
@@ -99,67 +103,88 @@ def tokenize(self, text, never_split=None):
         # words in the English Wikipedia.).
         if self.tokenize_chinese_chars:
             text = self._tokenize_chinese_chars(text)
-        # prevents treating the same character with different unicode codepoints as different characters
-        unicode_normalized_text = unicodedata.normalize("NFC", text)
+
+        # Use a local variable to avoid attribute lookups in the hot path
+        unicodedata_normalize = self._unicodedata_normalize
+        unicode_normalized_text = unicodedata_normalize("NFC", text)
+
         orig_tokens = whitespace_tokenize(unicode_normalized_text)
         split_tokens = []
+        never_split_contains = never_split.__contains__
+
+        # Cache method references for the hot path
+        do_lower_case = self.do_lower_case
+        strip_accents = self.strip_accents
+        _run_strip_accents = self._run_strip_accents
+        _run_split_on_punc = self._run_split_on_punc
+
         for token in orig_tokens:
-            if token not in never_split:
-                if self.do_lower_case:
+            if not never_split_contains(token):
+                if do_lower_case:
                     token = token.lower()
-                    if self.strip_accents is not False:
-                        token = self._run_strip_accents(token)
-                elif self.strip_accents:
-                    token = self._run_strip_accents(token)
-            split_tokens.extend(self._run_split_on_punc(token, never_split))
-
-        output_tokens = whitespace_tokenize(" ".join(split_tokens))
+                    if strip_accents is not False:
+                        token = _run_strip_accents(token)
+                elif strip_accents:
+                    token = _run_strip_accents(token)
+            split_tokens.extend(_run_split_on_punc(token, never_split))
+
+        # Avoid re-joining the string when there is only a single token
+        if len(split_tokens) == 1:
+            output_tokens = whitespace_tokenize(split_tokens[0])
+        else:
+            output_tokens = whitespace_tokenize(" ".join(split_tokens))
         return output_tokens

     def _run_strip_accents(self, text):
         """Strips accents from a piece of text."""
-        text = unicodedata.normalize("NFD", text)
+        unicodedata_normalize = self._unicodedata_normalize
+        unicodedata_category = self._unicodedata_category
+
+        # Optimize by avoiding attribute lookups
+        text = unicodedata_normalize("NFD", text)
         output = []
+        append = output.append
+        # Skip combining marks (category "Mn")
         for char in text:
-            cat = unicodedata.category(char)
-            if cat == "Mn":
+            if unicodedata_category(char) == "Mn":
                 continue
-            output.append(char)
+            append(char)
         return "".join(output)

     def _run_split_on_punc(self, text, never_split=None):
         """Splits punctuation on a piece of text."""
         if not self.do_split_on_punc or (never_split is not None and text in never_split):
             return [text]
-        chars = list(text)
-        i = 0
-        start_new_word = True
+
+        # Optimize by avoiding creating a list of all characters
+        _is_punctuation_local = _is_punctuation
         output = []
-        while i < len(chars):
-            char = chars[i]
-            if _is_punctuation(char):
-                output.append([char])
-                start_new_word = True
-            else:
-                if start_new_word:
-                    output.append([])
-                start_new_word = False
-                output[-1].append(char)
-            i += 1
+        current_word = []

-        return ["".join(x) for x in output]
+        for char in text:
+            if _is_punctuation_local(char):
+                if current_word:
+                    output.append("".join(current_word))
+                    current_word = []
+                output.append(char)
+            else:
+                current_word.append(char)
+        if current_word:
+            output.append("".join(current_word))
+        return output

     def _tokenize_chinese_chars(self, text):
         """Adds whitespace around any CJK character."""
         output = []
+        append = output.append
+        _is_chinese_char = self._is_chinese_char
         for char in text:
-            cp = ord(char)
-            if self._is_chinese_char(cp):
-                output.append(" ")
-                output.append(char)
-                output.append(" ")
+            if _is_chinese_char(ord(char)):
+                append(" ")
+                append(char)
+                append(" ")
             else:
-                output.append(char)
+                append(char)
         return "".join(output)

     def _is_chinese_char(self, cp):
@@ -189,14 +214,17 @@ def _is_chinese_char(self, cp):
     def _clean_text(self, text):
         """Performs invalid character removal and whitespace cleanup on text."""
         output = []
+        append = output.append
+        _is_control_local = _is_control
+        _is_whitespace_local = _is_whitespace
         for char in text:
             cp = ord(char)
-            if cp == 0 or cp == 0xFFFD or _is_control(char):
+            if cp == 0 or cp == 0xFFFD or _is_control_local(char):
                 continue
-            if _is_whitespace(char):
-                output.append(" ")
+            if _is_whitespace_local(char):
+                append(" ")
             else:
-                output.append(char)
+                append(char)
         return "".join(output)
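Note (not part of the patch): the _run_split_on_punc rewrite changes the accumulation strategy, building word strings directly instead of a list of character lists, so a quick behavioural comparison against the original implementation is a cheap way to confirm the two stay equivalent. The sketch below is illustrative only: split_old and split_new are hypothetical names, and the _is_punctuation stand-in merely approximates the helper imported in this module.

# Standalone equivalence check (sketch, assuming a simplified _is_punctuation stand-in)
import unicodedata


def _is_punctuation(char):
    # Approximation of the tokenizer's helper: ASCII symbol ranges plus any
    # Unicode character whose category starts with "P" (punctuation).
    cp = ord(char)
    if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
        return True
    return unicodedata.category(char).startswith("P")


def split_old(text):
    # Original algorithm: build a list of character lists, join at the end.
    chars = list(text)
    i = 0
    start_new_word = True
    output = []
    while i < len(chars):
        char = chars[i]
        if _is_punctuation(char):
            output.append([char])
            start_new_word = True
        else:
            if start_new_word:
                output.append([])
            start_new_word = False
            output[-1].append(char)
        i += 1
    return ["".join(x) for x in output]


def split_new(text):
    # Rewritten algorithm from the patch: accumulate the current word directly.
    output = []
    current_word = []
    for char in text:
        if _is_punctuation(char):
            if current_word:
                output.append("".join(current_word))
                current_word = []
            output.append(char)
        else:
            current_word.append(char)
    if current_word:
        output.append("".join(current_word))
    return output


for sample in ["hello, world!", "a.b.c", "no punct", "!!", ""]:
    assert split_old(sample) == split_new(sample), sample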