diff --git a/src/transformers/models/prophetnet/tokenization_prophetnet.py b/src/transformers/models/prophetnet/tokenization_prophetnet.py
index 24401835c7fc..3017ec66b1e0 100644
--- a/src/transformers/models/prophetnet/tokenization_prophetnet.py
+++ b/src/transformers/models/prophetnet/tokenization_prophetnet.py
@@ -87,8 +87,11 @@ def tokenize(self, text, never_split=None):
                 Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                 [`PreTrainedTokenizer.tokenize`]) List of token not to split.
         """
-        # union() returns a new set by concatenating the two sets.
-        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+        # minor optimization: skip the set union when never_split is empty to avoid allocating a new set
+        if never_split:
+            never_split_set = self.never_split.union(set(never_split))
+        else:
+            never_split_set = self.never_split
         text = self._clean_text(text)
 
         # This was added on November 1st, 2018 for the multilingual and Chinese
@@ -98,23 +101,44 @@ def tokenize(self, text, never_split=None):
         # characters in the vocabulary because Wikipedia does have some Chinese
         # words in the English Wikipedia.).
         if self.tokenize_chinese_chars:
-            text = self._tokenize_chinese_chars(text)
+            if _may_have_chinese_char(text):
+                text = self._tokenize_chinese_chars(text)
+        # NFC normalization
         # prevents treating the same character with different unicode codepoints as different characters
         unicode_normalized_text = unicodedata.normalize("NFC", text)
         orig_tokens = whitespace_tokenize(unicode_normalized_text)
         split_tokens = []
+
+        # Cache flags that are consulted for every token (lowercasing, accent stripping)
+        ds_lower_case = self.do_lower_case
+        ds_strip_accents = self.strip_accents
+
+        # Bind list methods and the never_split set to locals for the tight loop
+        append = split_tokens.append
+        extend = split_tokens.extend
+        nsplit = never_split_set
+
+        # Avoid repeated bound-method lookups for _run_split_on_punc and _run_strip_accents
+        split_on_punc = self._run_split_on_punc
+        run_strip_accents = self._run_strip_accents
         for token in orig_tokens:
-            if token not in never_split:
-                if self.do_lower_case:
+            # Fast membership test with set
+            if token not in nsplit:
+                # Only lowercase and/or strip accents if requested.
+                if ds_lower_case:
                     token = token.lower()
-                    if self.strip_accents is not False:
-                        token = self._run_strip_accents(token)
-                elif self.strip_accents:
-                    token = self._run_strip_accents(token)
-            split_tokens.extend(self._run_split_on_punc(token, never_split))
-
-        output_tokens = whitespace_tokenize(" ".join(split_tokens))
-        return output_tokens
+                    if ds_strip_accents is not False:
+                        token = run_strip_accents(token)
+                elif ds_strip_accents:
+                    token = run_strip_accents(token)
+            # This can return multiple tokens; minimize extension overhead
+            extend(split_on_punc(token, nsplit))
+
+        # A generator would not help here: whitespace_tokenize calls str.split()
+        # on a string, so the " ".join(split_tokens) result has to be
+        # materialized either way.
+        output = " ".join(split_tokens)
+        return whitespace_tokenize(output)
 
     def _run_strip_accents(self, text):
         """Strips accents from a piece of text."""
@@ -199,6 +223,20 @@ def _clean_text(self, text):
                 output.append(char)
         return "".join(output)
 
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # Check only the CJK Unicode blocks, as in the original implementation, for performance
+        return (
+            (0x4E00 <= cp <= 0x9FFF)
+            or (0x3400 <= cp <= 0x4DBF)
+            or (0x20000 <= cp <= 0x2A6DF)
+            or (0x2A700 <= cp <= 0x2B73F)
+            or (0x2B740 <= cp <= 0x2B81F)
+            or (0x2B820 <= cp <= 0x2CEAF)
+            or (0xF900 <= cp <= 0xFAFF)
+            or (0x2F800 <= cp <= 0x2FA1F)
+        )
+
 
 # Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
 class WordpieceTokenizer:
@@ -224,24 +262,34 @@ def tokenize(self, text):
             A list of wordpiece tokens.
         """
 
+        # whitespace_tokenize is already fast; keep it as-is to preserve behavior
+        tokens = whitespace_tokenize(text)
         output_tokens = []
-        for token in whitespace_tokenize(text):
-            chars = list(token)
-            if len(chars) > self.max_input_chars_per_word:
-                output_tokens.append(self.unk_token)
+
+        vocab = self.vocab
+        max_len = self.max_input_chars_per_word
+        unk_token = self.unk_token
+
+        # Use the locals above to avoid repeated attribute lookups inside the loop
+        for token in tokens:
+            chars = token
+            if len(chars) > max_len:
+                output_tokens.append(unk_token)
                 continue
 
             is_bad = False
             start = 0
             sub_tokens = []
-            while start < len(chars):
-                end = len(chars)
+            L = len(chars)
+            while start < L:
+                end = L
                 cur_substr = None
                 while start < end:
-                    substr = "".join(chars[start:end])
+                    substr = chars[start:end]
+                    # Only prepend "##" if this is not the first piece of the word
                     if start > 0:
                         substr = "##" + substr
-                    if substr in self.vocab:
+                    if substr in vocab:
                         cur_substr = substr
                         break
                     end -= 1
@@ -252,7 +300,7 @@ def tokenize(self, text):
                 start = end
 
             if is_bad:
-                output_tokens.append(self.unk_token)
+                output_tokens.append(unk_token)
             else:
                 output_tokens.extend(sub_tokens)
         return output_tokens
@@ -269,6 +317,24 @@ def load_vocab(vocab_file):
     return vocab
 
 
+# micro-optimization: quick scan for CJK codepoints so the caller can skip _tokenize_chinese_chars when there are none
+def _may_have_chinese_char(text):
+    for t in text:
+        c = ord(t)
+        if (
+            (0x4E00 <= c <= 0x9FFF)
+            or (0x3400 <= c <= 0x4DBF)
+            or (0x20000 <= c <= 0x2A6DF)
+            or (0x2A700 <= c <= 0x2B73F)
+            or (0x2B740 <= c <= 0x2B81F)
+            or (0x2B820 <= c <= 0x2CEAF)
+            or (0xF900 <= c <= 0xFAFF)
+            or (0x2F800 <= c <= 0x2FA1F)
+        ):
+            return True
+    return False
+
+
 class ProphetNetTokenizer(PreTrainedTokenizer):
     r"""
     Construct a ProphetNetTokenizer. Based on WordPiece.
@@ -377,17 +443,24 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     def _tokenize(self, text):
-        split_tokens = []
-        if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
-                # If the token is part of the never_split set
-                if token in self.basic_tokenizer.never_split:
+        # Bind frequently used attributes to locals up front
+        do_basic = self.do_basic_tokenize
+        basic_tok = getattr(self, "basic_tokenizer", None)
+        wordpiece_tok = self.wordpiece_tokenizer
+        all_special_tokens = self.all_special_tokens
+
+        if do_basic:
+            nsplit = basic_tok.never_split
+            split_tokens = []
+            # Avoid repeated attribute lookups and list ops inside the loop
+            for token in basic_tok.tokenize(text, never_split=all_special_tokens):
+                if token in nsplit:
                     split_tokens.append(token)
                 else:
-                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
+                    split_tokens.extend(wordpiece_tok.tokenize(token))
+            return split_tokens
         else:
-            split_tokens = self.wordpiece_tokenizer.tokenize(text)
-        return split_tokens
+            return wordpiece_tok.tokenize(text)
 
     def _convert_token_to_id(self, token: str):
         """Converts a token (str) in an id using the vocab."""
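
For reference, a minimal standalone sketch (not part of the patch): it copies the CJK block checks used by _may_have_chinese_char and times the pre-scan on ASCII-only versus mixed input with timeit. The sample strings and iteration counts below are arbitrary and the measured numbers will vary by machine.

# Standalone sketch of the CJK pre-scan idea; sample inputs and repeat counts are illustrative only.
import timeit


def may_have_chinese_char(text):
    # Same CJK block checks as in the patch; returns True on the first hit.
    for ch in text:
        c = ord(ch)
        if (
            (0x4E00 <= c <= 0x9FFF)
            or (0x3400 <= c <= 0x4DBF)
            or (0x20000 <= c <= 0x2A6DF)
            or (0x2A700 <= c <= 0x2B73F)
            or (0x2B740 <= c <= 0x2B81F)
            or (0x2B820 <= c <= 0x2CEAF)
            or (0xF900 <= c <= 0xFAFF)
            or (0x2F800 <= c <= 0x2FA1F)
        ):
            return True
    return False


ascii_text = "the quick brown fox jumps over the lazy dog " * 50
mixed_text = "你好世界 " + ascii_text  # CJK near the front -> early exit

for name, sample in [("ascii-only", ascii_text), ("mixed", mixed_text)]:
    elapsed = timeit.timeit(lambda: may_have_chinese_char(sample), number=2000)
    print(f"{name}: {elapsed:.4f}s for 2000 scans, result={may_have_chinese_char(sample)}")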