mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-29 21:03:13 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -16,7 +16,6 @@
|
||||
|
||||
from .tokenizer import build_tokenizer
|
||||
|
||||
|
||||
_TOKENIZER = None
|
||||
_PADDED_VOCAB_SIZE = -1
|
||||
|
||||
|
||||
@@ -15,13 +15,12 @@
|
||||
|
||||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import collections
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
import six
|
||||
|
||||
|
||||
@@ -43,14 +42,13 @@ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
||||
model_name = m.group(1)
|
||||
|
||||
lower_models = [
|
||||
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
|
||||
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
|
||||
"uncased_L-24_H-1024_A-16",
|
||||
"uncased_L-12_H-768_A-12",
|
||||
"multilingual_L-12_H-768_A-12",
|
||||
"chinese_L-12_H-768_A-12",
|
||||
]
|
||||
|
||||
cased_models = [
|
||||
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
|
||||
"multi_cased_L-12_H-768_A-12"
|
||||
]
|
||||
cased_models = ["cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", "multi_cased_L-12_H-768_A-12"]
|
||||
|
||||
is_bad_config = False
|
||||
if model_name in lower_models and not do_lower_case:
|
||||
@@ -71,8 +69,8 @@ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
|
||||
"However, `%s` seems to be a %s model, so you "
|
||||
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
|
||||
"how the model was pre-training. If this error is wrong, please "
|
||||
"just comment out this check." % (actual_flag, init_checkpoint,
|
||||
model_name, case_name, opposite_flag))
|
||||
"just comment out this check." % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag)
|
||||
)
|
||||
|
||||
|
||||
def convert_to_unicode(text):
|
||||
@@ -183,27 +181,27 @@ class FullTokenizer(object):
|
||||
|
||||
@staticmethod
|
||||
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
|
||||
def clean_up_tokenization(out_string):
|
||||
""" Clean up a list of simple English tokenization artifacts
|
||||
"""Clean up a list of simple English tokenization artifacts
|
||||
like spaces before punctuations and abbreviated forms.
|
||||
"""
|
||||
out_string = (
|
||||
out_string.replace(" .", ".")
|
||||
.replace(" ?", "?")
|
||||
.replace(" !", "!")
|
||||
.replace(" ,", ",")
|
||||
.replace(" ' ", "'")
|
||||
.replace(" n't", "n't")
|
||||
.replace(" 'm", "'m")
|
||||
.replace(" 's", "'s")
|
||||
.replace(" 've", "'ve")
|
||||
.replace(" 're", "'re")
|
||||
.replace(" ?", "?")
|
||||
.replace(" !", "!")
|
||||
.replace(" ,", ",")
|
||||
.replace(" ' ", "'")
|
||||
.replace(" n't", "n't")
|
||||
.replace(" 'm", "'m")
|
||||
.replace(" 's", "'s")
|
||||
.replace(" 've", "'ve")
|
||||
.replace(" 're", "'re")
|
||||
)
|
||||
return out_string
|
||||
|
||||
text = ' '.join(tokens).replace(' ##', '').strip()
|
||||
text = " ".join(tokens).replace(" ##", "").strip()
|
||||
if clean_up_tokenization_spaces:
|
||||
clean_text = clean_up_tokenization(text)
|
||||
return clean_text
|
||||
@@ -303,14 +301,16 @@ class BasicTokenizer(object):
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like the all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or
|
||||
(cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||
if (
|
||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
||||
): #
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -320,7 +320,7 @@ class BasicTokenizer(object):
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or _is_control(char):
|
||||
if cp == 0 or cp == 0xFFFD or _is_control(char):
|
||||
continue
|
||||
if _is_whitespace(char):
|
||||
output.append(" ")
|
||||
@@ -422,8 +422,7 @@ def _is_punctuation(char):
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
|
||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
|
||||
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
|
||||
@@ -25,16 +25,15 @@ from .bert_tokenization import FullTokenizer as FullBertTokenizer
|
||||
def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
|
||||
"""Initialize tokenizer."""
|
||||
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
|
||||
print('> building {} tokenizer ...'.format(tokenizer_type), flush=True)
|
||||
print("> building {} tokenizer ...".format(tokenizer_type), flush=True)
|
||||
|
||||
# Select and instantiate the tokenizer.
|
||||
if tokenizer_type == 'BertWordPieceLowerCase':
|
||||
if tokenizer_type == "BertWordPieceLowerCase":
|
||||
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids)
|
||||
elif tokenizer_type == 'BertWordPieceCase':
|
||||
elif tokenizer_type == "BertWordPieceCase":
|
||||
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids)
|
||||
else:
|
||||
raise NotImplementedError('{} tokenizer is not '
|
||||
'implemented.'.format(tokenizer_type))
|
||||
raise NotImplementedError("{} tokenizer is not " "implemented.".format(tokenizer_type))
|
||||
|
||||
# Add vocab size.
|
||||
padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size)
|
||||
@@ -55,9 +54,11 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128):
|
||||
while (after % multiple) != 0:
|
||||
after += 1
|
||||
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
|
||||
print(' > padded vocab (size: {}) with {} dummy tokens '
|
||||
'(new size: {})'.format(orig_vocab_size, after - orig_vocab_size, after),
|
||||
flush=True)
|
||||
print(
|
||||
" > padded vocab (size: {}) with {} dummy tokens "
|
||||
"(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after),
|
||||
flush=True,
|
||||
)
|
||||
return after
|
||||
|
||||
|
||||
@@ -77,46 +78,38 @@ class AbstractTokenizer(ABC):
|
||||
@abstractmethod
|
||||
def vocab(self):
|
||||
"""Dictionary from vocab text token to id token."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def inv_vocab(self):
|
||||
"""Dictionary from vocab id token to text token."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def tokenize(self, text):
|
||||
pass
|
||||
|
||||
def detokenize(self, token_ids):
|
||||
raise NotImplementedError('detokenizer is not implemented for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name))
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
raise NotImplementedError('CLS is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name))
|
||||
|
||||
@property
|
||||
def sep(self):
|
||||
raise NotImplementedError('SEP is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name))
|
||||
|
||||
@property
|
||||
def pad(self):
|
||||
raise NotImplementedError('PAD is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name))
|
||||
|
||||
@property
|
||||
def eod(self):
|
||||
raise NotImplementedError('EOD is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name))
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
raise NotImplementedError('MASK is not provided for {} '
|
||||
'tokenizer'.format(self.name))
|
||||
raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name))
|
||||
|
||||
|
||||
class _BertWordPieceTokenizer(AbstractTokenizer):
|
||||
@@ -124,24 +117,24 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
|
||||
|
||||
def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
|
||||
if lower_case:
|
||||
name = 'BERT Lower Case'
|
||||
name = "BERT Lower Case"
|
||||
else:
|
||||
name = 'BERT Upper Case'
|
||||
name = "BERT Upper Case"
|
||||
super().__init__(name)
|
||||
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
|
||||
self.cls_id = self.tokenizer.vocab['[CLS]']
|
||||
self.sep_id = self.tokenizer.vocab['[SEP]']
|
||||
self.pad_id = self.tokenizer.vocab['[PAD]']
|
||||
self.mask_id = self.tokenizer.vocab['[MASK]']
|
||||
self.cls_id = self.tokenizer.vocab["[CLS]"]
|
||||
self.sep_id = self.tokenizer.vocab["[SEP]"]
|
||||
self.pad_id = self.tokenizer.vocab["[PAD]"]
|
||||
self.mask_id = self.tokenizer.vocab["[MASK]"]
|
||||
self._additional_special_tokens = []
|
||||
|
||||
# (dsachan) Add BOS and EOS tokens
|
||||
SPECIAL_TOKENS = {'eos_token': '[EOS]', 'bos_token': '[BOS]'}
|
||||
self._bos_token = '[BOS]'
|
||||
SPECIAL_TOKENS = {"eos_token": "[EOS]", "bos_token": "[BOS]"}
|
||||
self._bos_token = "[BOS]"
|
||||
self.add_token(self._bos_token)
|
||||
self._bos_token_id = self.vocab.get(self._bos_token)
|
||||
|
||||
self._eos_token = '[EOS]'
|
||||
self._eos_token = "[EOS]"
|
||||
self.add_token(self._eos_token)
|
||||
self._eos_token_id = self.vocab.get(self._eos_token)
|
||||
|
||||
@@ -185,7 +178,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
|
||||
|
||||
def decode_token_ids(self, token_ids):
|
||||
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
|
||||
exclude_list = ['[PAD]', '[CLS]']
|
||||
exclude_list = ["[PAD]", "[CLS]"]
|
||||
non_pads = [t for t in tokens if t not in exclude_list]
|
||||
|
||||
result = ""
|
||||
@@ -215,32 +208,32 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
""" Beginning of sentence token id """
|
||||
"""Beginning of sentence token id"""
|
||||
return self._bos_token
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
""" End of sentence token id """
|
||||
"""End of sentence token id"""
|
||||
return self._eos_token
|
||||
|
||||
@property
|
||||
def additional_special_tokens(self):
|
||||
""" All the additional special tokens you may want to use (list of strings)."""
|
||||
"""All the additional special tokens you may want to use (list of strings)."""
|
||||
return self._additional_special_tokens
|
||||
|
||||
@property
|
||||
def bos_token_id(self):
|
||||
""" Id of the beginning of sentence token in the vocabulary."""
|
||||
"""Id of the beginning of sentence token in the vocabulary."""
|
||||
return self._bos_token_id
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
""" Id of the end of sentence token in the vocabulary."""
|
||||
"""Id of the end of sentence token in the vocabulary."""
|
||||
return self._eos_token_id
|
||||
|
||||
@property
|
||||
def additional_special_tokens_ids(self):
|
||||
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
|
||||
"""Ids of all the additional special tokens in the vocabulary (list of integers)."""
|
||||
return [self.vocab.get(token) for token in self._additional_special_tokens]
|
||||
|
||||
@additional_special_tokens.setter
|
||||
|
||||
Reference in New Issue
Block a user