112 changes: 70 additions & 42 deletions src/transformers/models/prophetnet/tokenization_prophetnet.py
@@ -34,8 +34,8 @@ def whitespace_tokenize(text):
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
# Avoid intermediate variable, just return directly
return text.split()
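
A quick note on the change above: `str.split()` with no arguments already trims the ends and collapses runs of whitespace, so returning it directly is behavior-preserving. Minimal illustration (not from the PR):

```python
# str.split() with no separator trims and collapses whitespace on its own.
assert "  hello   world ".split() == ["hello", "world"]
```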


# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
@@ -78,6 +78,10 @@ def __init__(
self.strip_accents = strip_accents
self.do_split_on_punc = do_split_on_punc

# Cache commonly used unicodedata functions as instance attributes for performance
self._unicodedata_normalize = unicodedata.normalize
self._unicodedata_category = unicodedata.category
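
The two attributes added above pre-bind `unicodedata.normalize` and `unicodedata.category` so the hot paths below can skip a dotted module lookup per call. A minimal micro-benchmark sketch of that general idea (illustrative only, not part of the PR; the absolute gain is interpreter- and input-dependent):

```python
# Compare a repeated dotted lookup against a name bound once outside the loop.
import timeit
import unicodedata

text = "déjà vu " * 100

def dotted_lookup():
    for ch in text:
        unicodedata.category(ch)  # attribute lookup on every iteration

def prebound():
    category = unicodedata.category  # bind once, reuse inside the loop
    for ch in text:
        category(ch)

print(timeit.timeit(dotted_lookup, number=2000))
print(timeit.timeit(prebound, number=2000))
```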

def tokenize(self, text, never_split=None):
"""
Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
@@ -99,67 +103,88 @@ def tokenize(self, text, never_split=None):
# words in the English Wikipedia.).
if self.tokenize_chinese_chars:
text = self._tokenize_chinese_chars(text)
# prevents treating the same character with different unicode codepoints as different characters
unicode_normalized_text = unicodedata.normalize("NFC", text)

# Using local var to avoid attribute lookup in loop/hot path
unicodedata_normalize = self._unicodedata_normalize
unicode_normalized_text = unicodedata_normalize("NFC", text)

orig_tokens = whitespace_tokenize(unicode_normalized_text)
split_tokens = []
never_split_contains = never_split.__contains__

# Cache method references for hot path
do_lower_case = self.do_lower_case
strip_accents = self.strip_accents
_run_strip_accents = self._run_strip_accents
_run_split_on_punc = self._run_split_on_punc

for token in orig_tokens:
if token not in never_split:
if self.do_lower_case:
if not never_split_contains(token):
if do_lower_case:
token = token.lower()
if self.strip_accents is not False:
token = self._run_strip_accents(token)
elif self.strip_accents:
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token, never_split))

output_tokens = whitespace_tokenize(" ".join(split_tokens))
if strip_accents is not False:
token = _run_strip_accents(token)
elif strip_accents:
token = _run_strip_accents(token)
split_tokens.extend(_run_split_on_punc(token, never_split))

# Avoid recreating the string if single token
if len(split_tokens) == 1:
output_tokens = whitespace_tokenize(split_tokens[0])
else:
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
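
Since `tokenize` is only being refactored for speed here, its output should not change. A small sanity-check sketch, assuming the usual `BasicTokenizer` defaults (`do_lower_case=True`, punctuation splitting on, accents stripped when lower-casing):

```python
from transformers.models.prophetnet.tokenization_prophetnet import BasicTokenizer

tokenizer = BasicTokenizer()
print(tokenizer.tokenize("Hello, Wörld!"))
# Expected before and after this PR: ['hello', ',', 'world', '!']
```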

def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
unicodedata_normalize = self._unicodedata_normalize
unicodedata_category = self._unicodedata_category

# Optimize by avoiding attribute lookups
text = unicodedata_normalize("NFD", text)
output = []
append = output.append
# Keep every character except combining marks (category "Mn")
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
if unicodedata_category(char) == "Mn":
continue
output.append(char)
append(char)
return "".join(output)

def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if not self.do_split_on_punc or (never_split is not None and text in never_split):
return [text]
chars = list(text)
i = 0
start_new_word = True

# Optimize by avoiding creating a list of all characters
_is_punctuation_local = _is_punctuation
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
current_word = []

return ["".join(x) for x in output]
for char in text:
if _is_punctuation_local(char):
if current_word:
output.append("".join(current_word))
current_word = []
output.append(char)
else:
current_word.append(char)
if current_word:
output.append("".join(current_word))
return output
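
The rewrite of `_run_split_on_punc` replaces the list-of-lists bookkeeping with a simple accumulate-and-flush loop; both shapes should produce the same result. An equivalence sketch (illustrative; `is_punct` below is a simplified stand-in for the library's `_is_punctuation`, which additionally treats ASCII symbols such as `$` or `^` as punctuation):

```python
import unicodedata

def is_punct(ch):
    # simplified stand-in: Unicode punctuation categories only
    return unicodedata.category(ch).startswith("P")

def split_old(text):
    # original approach: list of character lists, joined at the end
    output, start_new_word = [], True
    for char in text:
        if is_punct(char):
            output.append([char])
            start_new_word = True
        else:
            if start_new_word:
                output.append([])
            start_new_word = False
            output[-1].append(char)
    return ["".join(x) for x in output]

def split_new(text):
    # refactored approach: accumulate a word, flush it whenever punctuation appears
    output, current_word = [], []
    for char in text:
        if is_punct(char):
            if current_word:
                output.append("".join(current_word))
                current_word = []
            output.append(char)
        else:
            current_word.append(char)
    if current_word:
        output.append("".join(current_word))
    return output

sample = "don't-stop!"
assert split_old(sample) == split_new(sample) == ["don", "'", "t", "-", "stop", "!"]
```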

def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
append = output.append
_is_chinese_char = self._is_chinese_char
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
if _is_chinese_char(ord(char)):
append(" ")
append(char)
append(" ")
else:
output.append(char)
append(char)
return "".join(output)

def _is_chinese_char(self, cp):
@@ -189,14 +214,17 @@ def _is_chinese_char(self, cp):
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
append = output.append
_is_control_local = _is_control
_is_whitespace_local = _is_whitespace
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xFFFD or _is_control(char):
if cp == 0 or cp == 0xFFFD or _is_control_local(char):
continue
if _is_whitespace(char):
output.append(" ")
if _is_whitespace_local(char):
append(" ")
else:
output.append(char)
append(char)
return "".join(output)

