133 changes: 103 additions & 30 deletions src/transformers/models/prophetnet/tokenization_prophetnet.py
@@ -87,8 +87,11 @@ def tokenize(self, text, never_split=None):
Kept for backward compatibility purposes. Now implemented directly at the base class level (see
[`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
"""
# union() returns a new set by concatenating the two sets.
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
# Minor optimization: skip the set() and union() allocations when no extra tokens are passed
if never_split:
never_split_set = self.never_split.union(set(never_split))
else:
never_split_set = self.never_split
text = self._clean_text(text)

# This was added on November 1st, 2018 for the multilingual and Chinese
@@ -98,23 +101,44 @@ def tokenize(self, text, never_split=None):
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
if self.tokenize_chinese_chars:
text = self._tokenize_chinese_chars(text)
if _may_have_chinese_char(text):
text = self._tokenize_chinese_chars(text)
# NFC normalization: prevents treating the same character, written with different
# Unicode codepoint sequences, as different characters
unicode_normalized_text = unicodedata.normalize("NFC", text)
orig_tokens = whitespace_tokenize(unicode_normalized_text)
split_tokens = []

# Cache the per-token options (lower-casing, accent stripping) as locals
ds_lower_case = self.do_lower_case
ds_strip_accents = self.strip_accents

# Bind the list methods and the never_split set to locals to keep the tight loop cheap
append = split_tokens.append
extend = split_tokens.extend
nsplit = never_split_set

# Avoid per-iteration attribute lookups for the bound helper methods
split_on_punc = self._run_split_on_punc
run_strip_accents = self._run_strip_accents
for token in orig_tokens:
if token not in never_split:
if self.do_lower_case:
# Fast membership test with set
if token not in nsplit:
# Only lowercase and/or strip accents if requested.
if ds_lower_case:
token = token.lower()
if self.strip_accents is not False:
token = self._run_strip_accents(token)
elif self.strip_accents:
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token, never_split))

output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
if ds_strip_accents is not False:
token = run_strip_accents(token)
elif ds_strip_accents:
token = run_strip_accents(token)
# This can return multiple tokens; minimize extension overhead
extend(split_on_punc(token, nsplit))

# whitespace_tokenize() calls str.split() internally, so the intermediate
# " ".join(split_tokens) string cannot be avoided; split_tokens is usually small,
# so a generator would not buy anything here.
output = " ".join(split_tokens)
return whitespace_tokenize(output)
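
# Illustrative sketch (not part of the patch): the two micro-optimizations above in
# isolation -- (1) skip the set()/union() allocation when no extra tokens are passed,
# and (2) bind attributes to locals before a hot loop. The class and names below are
# hypothetical stand-ins, not transformers APIs.
class _DemoTokenizer:
    def __init__(self, never_split=None, lower=True):
        self.never_split = set(never_split or ())
        self.lower = lower

    def tokenize(self, text, never_split=None):
        # (1) Only build a merged set when the caller actually passed extra tokens.
        nsplit = self.never_split.union(never_split) if never_split else self.never_split
        # (2) Local bindings avoid repeated attribute lookups inside the loop.
        lower = self.lower
        out = []
        append = out.append
        for tok in text.split():
            if tok not in nsplit and lower:
                tok = tok.lower()
            append(tok)
        return out

# _DemoTokenizer(never_split=["[SEP]"]).tokenize("Hello [SEP] World")
# -> ['hello', '[SEP]', 'world']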

def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
@@ -199,6 +223,20 @@ def _clean_text(self, text):
output.append(char)
return "".join(output)

def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# Check the same CJK Unicode blocks as the original implementation; plain range comparisons keep this fast
return (
(0x4E00 <= cp <= 0x9FFF)
or (0x3400 <= cp <= 0x4DBF)
or (0x20000 <= cp <= 0x2A6DF)
or (0x2A700 <= cp <= 0x2B73F)
or (0x2B740 <= cp <= 0x2B81F)
or (0x2B820 <= cp <= 0x2CEAF)
or (0xF900 <= cp <= 0xFAFF)
or (0x2F800 <= cp <= 0x2FA1F)
)
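
# Illustrative sketch (not part of the patch): a per-character check like the one above
# is typically used to pad CJK characters with spaces so each one becomes its own token.
# `is_cjk` is a caller-supplied predicate standing in for _is_chinese_char.
def _demo_pad_cjk(text, is_cjk):
    out = []
    for ch in text:
        if is_cjk(ord(ch)):
            # Surround each CJK character with spaces so downstream whitespace
            # splitting treats it as a separate token.
            out.extend((" ", ch, " "))
        else:
            out.append(ch)
    return "".join(out)

# _demo_pad_cjk("ab中c", lambda cp: 0x4E00 <= cp <= 0x9FFF) -> 'ab 中 c'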


# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
class WordpieceTokenizer:
@@ -224,24 +262,34 @@ def tokenize(self, text):
A list of wordpiece tokens.
"""

# whitespace_tokenize is already fast; keep it as-is to preserve behavior
tokens = whitespace_tokenize(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)

vocab = self.vocab
max_len = self.max_input_chars_per_word
unk_token = self.unk_token

# Bind instance attributes to locals to avoid repeated lookups inside the loop
for token in tokens:
chars = token
if len(chars) > max_len:
output_tokens.append(unk_token)
continue

is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
L = len(chars)
while start < L:
end = L
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
substr = chars[start:end]
# Only prepend "##" to continuation pieces (not the first piece of the word)
if start > 0:
substr = "##" + substr
if substr in self.vocab:
if substr in vocab:
cur_substr = substr
break
end -= 1
@@ -252,7 +300,7 @@ def tokenize(self, text):
start = end

if is_bad:
output_tokens.append(self.unk_token)
output_tokens.append(unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
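
# Illustrative sketch (not part of the patch): the loop above is greedy
# longest-match-first WordPiece. A toy, self-contained version with a made-up vocab:
def _demo_wordpiece(word, vocab, unk="[UNK]"):
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        cur = None
        while start < end:
            sub = word[start:end]
            if start > 0:
                sub = "##" + sub  # continuation pieces carry the "##" prefix
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:
            return [unk]  # no prefix of the remaining characters is in the vocab
        pieces.append(cur)
        start = end
    return pieces

# _demo_wordpiece("unaffable", {"un", "##aff", "##able"}) -> ['un', '##aff', '##able']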
@@ -269,6 +317,24 @@ def load_vocab(vocab_file):
return vocab


# micro-optimization: quickly scan for CJK codepoints, short-circuit if none present
def _may_have_chinese_char(text):
for t in text:
c = ord(t)
if (
(0x4E00 <= c <= 0x9FFF)
or (0x3400 <= c <= 0x4DBF)
or (0x20000 <= c <= 0x2A6DF)
or (0x2A700 <= c <= 0x2B73F)
or (0x2B740 <= c <= 0x2B81F)
or (0x2B820 <= c <= 0x2CEAF)
or (0xF900 <= c <= 0xFAFF)
or (0x2F800 <= c <= 0x2FA1F)
):
return True
return False
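
# Illustrative sketch (not part of the patch): the intended call pattern for the
# pre-scan -- run the cheap check first and only pay for the character-by-character
# rewrite when a CJK codepoint is actually present. `rewrite` is a hypothetical
# stand-in for the tokenizer's _tokenize_chinese_chars method.
def _demo_cjk_gate(text, rewrite):
    if _may_have_chinese_char(text):
        return rewrite(text)
    # Pure non-CJK input: return the original string untouched.
    return text

# _demo_cjk_gate("hello world", rewrite=str.upper) -> 'hello world' (untouched)
# _demo_cjk_gate("hi 中文", rewrite=str.upper)      -> 'HI 中文'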


class ProphetNetTokenizer(PreTrainedTokenizer):
r"""
Construct a ProphetNetTokenizer. Based on WordPiece.
@@ -377,17 +443,24 @@ def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)

def _tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
# If the token is part of the never_split set
if token in self.basic_tokenizer.never_split:
# Bind frequently used attributes to locals for the hot path
do_basic = self.do_basic_tokenize
basic_tok = getattr(self, "basic_tokenizer", None)
wordpiece_tok = self.wordpiece_tokenizer
all_special_tokens = self.all_special_tokens

if do_basic:
nsplit = basic_tok.never_split
split_tokens = []
# Avoid repeated attr lookup and list ops inside the loop
for token in basic_tok.tokenize(text, never_split=all_special_tokens):
if token in nsplit:
split_tokens.append(token)
else:
split_tokens += self.wordpiece_tokenizer.tokenize(token)
split_tokens.extend(wordpiece_tok.tokenize(token))
return split_tokens
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
return wordpiece_tok.tokenize(text)
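
# Illustrative usage sketch (not part of the patch): the fast paths above are exercised
# through the public tokenize() API. The checkpoint name is an assumption; any
# ProphetNet checkpoint that ships this tokenizer behaves the same way.
from transformers import ProphetNetTokenizer

tok = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
# Special tokens such as "[SEP]" are protected by never_split and stay intact.
print(tok.tokenize("Hello [SEP] world"))  # e.g. ['hello', '[SEP]', 'world']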

def _convert_token_to_id(self, token: str):
"""Converts a token (str) in an id using the vocab."""