Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-30 21:39:05 +00:00)
[tutorial] edited hands-on practices (#1899)

* Add hands-on examples to ColossalAI.
* Rename the hands-on examples and edit the sequence parallel example.
* Fix a wrong folder name.
* Resolve conflicts.
* Delete the README.
		| @@ -0,0 +1,38 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
|  | ||||
| from .tokenizer import build_tokenizer | ||||
|  | ||||
|  | ||||
| _TOKENIZER = None | ||||
| _PADDED_VOCAB_SIZE = -1 | ||||
|  | ||||
|  | ||||
| def initialize_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): | ||||
|     tokenizer, padded_vocab_size = build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids) | ||||
|     global _TOKENIZER, _PADDED_VOCAB_SIZE | ||||
|     _TOKENIZER = tokenizer | ||||
|     _PADDED_VOCAB_SIZE = padded_vocab_size | ||||
|  | ||||
|  | ||||
| def get_tokenizer(): | ||||
|     global _TOKENIZER | ||||
|     return _TOKENIZER | ||||
|  | ||||
|  | ||||
| def get_padded_vocab_size(): | ||||
|     global _PADDED_VOCAB_SIZE | ||||
|     return _PADDED_VOCAB_SIZE | ||||
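The 38-line hunk above (most likely the tokenizer package's __init__.py, since the other files in this commit are pulled in through relative imports) caches a single tokenizer instance and its padded vocabulary size in module-level globals, so the rest of the sequence-parallel example can fetch them anywhere without re-reading the vocab file. A minimal usage sketch; the import path and the vocab file name are assumptions for illustration, not part of this commit:

    # Hedged sketch: "data.tokenizer" mirrors the examples/tutorial/sequence_parallel/data/tokenizer
    # folder layout, and "bert-vocab.txt" is a placeholder BERT WordPiece vocab file.
    from data.tokenizer import initialize_tokenizer, get_tokenizer, get_padded_vocab_size

    initialize_tokenizer("bert-vocab.txt", tokenizer_type="BertWordPieceLowerCase")

    tokenizer = get_tokenizer()         # returns the same cached object on every call
    ids = tokenizer.tokenize("sequence parallelism made simple")
    print(ids)                          # WordPiece token ids
    print(get_padded_vocab_size())      # vocab size padded for the embedding layer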
| @@ -0,0 +1,431 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright 2018 The Google AI Language Team Authors. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
| """Tokenization classes.""" | ||||
|  | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
|  | ||||
| import collections | ||||
| import re | ||||
| import unicodedata | ||||
| import six | ||||
|  | ||||
|  | ||||
| def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): | ||||
|     """Checks whether the casing config is consistent with the checkpoint name.""" | ||||
|  | ||||
|     # The casing has to be passed in by the user and there is no explicit check | ||||
|     # as to whether it matches the checkpoint. The casing information probably | ||||
|     # should have been stored in the bert_config.json file, but it's not, so | ||||
|     # we have to heuristically detect it to validate. | ||||
|  | ||||
|     if not init_checkpoint: | ||||
|         return | ||||
|  | ||||
|     m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) | ||||
|     if m is None: | ||||
|         return | ||||
|  | ||||
|     model_name = m.group(1) | ||||
|  | ||||
|     lower_models = [ | ||||
|         "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", | ||||
|         "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" | ||||
|     ] | ||||
|  | ||||
|     cased_models = [ | ||||
|         "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", | ||||
|         "multi_cased_L-12_H-768_A-12" | ||||
|     ] | ||||
|  | ||||
|     is_bad_config = False | ||||
|     if model_name in lower_models and not do_lower_case: | ||||
|         is_bad_config = True | ||||
|         actual_flag = "False" | ||||
|         case_name = "lowercased" | ||||
|         opposite_flag = "True" | ||||
|  | ||||
|     if model_name in cased_models and do_lower_case: | ||||
|         is_bad_config = True | ||||
|         actual_flag = "True" | ||||
|         case_name = "cased" | ||||
|         opposite_flag = "False" | ||||
|  | ||||
|     if is_bad_config: | ||||
|         raise ValueError( | ||||
|             "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " | ||||
|             "However, `%s` seems to be a %s model, so you " | ||||
|             "should pass in `--do_lower_case=%s` so that the fine-tuning matches " | ||||
|             "how the model was pre-trained. If this error is wrong, please " | ||||
|             "just comment out this check." % (actual_flag, init_checkpoint, | ||||
|                                               model_name, case_name, opposite_flag)) | ||||
|  | ||||
|  | ||||
| def convert_to_unicode(text): | ||||
|     """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" | ||||
|     if six.PY3: | ||||
|         if isinstance(text, str): | ||||
|             return text | ||||
|         elif isinstance(text, bytes): | ||||
|             return text.decode("utf-8", "ignore") | ||||
|         else: | ||||
|             raise ValueError("Unsupported string type: %s" % (type(text))) | ||||
|     elif six.PY2: | ||||
|         if isinstance(text, str): | ||||
|             return text.decode("utf-8", "ignore") | ||||
|         elif isinstance(text, unicode): | ||||
|             return text | ||||
|         else: | ||||
|             raise ValueError("Unsupported string type: %s" % (type(text))) | ||||
|     else: | ||||
|         raise ValueError("Not running on Python2 or Python 3?") | ||||
|  | ||||
|  | ||||
| def printable_text(text): | ||||
|     """Returns text encoded in a way suitable for print or `tf.logging`.""" | ||||
|  | ||||
|     # These functions want `str` for both Python2 and Python3, but in one case | ||||
|     # it's a Unicode string and in the other it's a byte string. | ||||
|     if six.PY3: | ||||
|         if isinstance(text, str): | ||||
|             return text | ||||
|         elif isinstance(text, bytes): | ||||
|             return text.decode("utf-8", "ignore") | ||||
|         else: | ||||
|             raise ValueError("Unsupported string type: %s" % (type(text))) | ||||
|     elif six.PY2: | ||||
|         if isinstance(text, str): | ||||
|             return text | ||||
|         elif isinstance(text, unicode): | ||||
|             return text.encode("utf-8") | ||||
|         else: | ||||
|             raise ValueError("Unsupported string type: %s" % (type(text))) | ||||
|     else: | ||||
|         raise ValueError("Not running on Python2 or Python 3?") | ||||
|  | ||||
|  | ||||
| def load_vocab(vocab_file): | ||||
|     """Loads a vocabulary file into a dictionary.""" | ||||
|     vocab = collections.OrderedDict() | ||||
|     index = 0 | ||||
|     with open(vocab_file, "r") as reader: | ||||
|         while True: | ||||
|             token = convert_to_unicode(reader.readline()) | ||||
|             if not token: | ||||
|                 break | ||||
|             token = token.strip() | ||||
|             vocab[token] = index | ||||
|             index += 1 | ||||
|     return vocab | ||||
|  | ||||
|  | ||||
| def convert_by_vocab(vocab, items): | ||||
|     """Converts a sequence of [tokens|ids] using the vocab.""" | ||||
|     output = [] | ||||
|     for item in items: | ||||
|         output.append(vocab[item]) | ||||
|     return output | ||||
|  | ||||
|  | ||||
| def convert_tokens_to_ids(vocab, tokens): | ||||
|     return convert_by_vocab(vocab, tokens) | ||||
|  | ||||
|  | ||||
| def convert_ids_to_tokens(inv_vocab, ids): | ||||
|     return convert_by_vocab(inv_vocab, ids) | ||||
|  | ||||
|  | ||||
| def whitespace_tokenize(text): | ||||
|     """Runs basic whitespace cleaning and splitting on a piece of text.""" | ||||
|     text = text.strip() | ||||
|     if not text: | ||||
|         return [] | ||||
|     tokens = text.split() | ||||
|     return tokens | ||||
|  | ||||
|  | ||||
| class FullTokenizer(object): | ||||
|     """Runs end-to-end tokenization.""" | ||||
|  | ||||
|     def __init__(self, vocab_file, do_lower_case=True): | ||||
|         self.vocab = load_vocab(vocab_file) | ||||
|         self.inv_vocab = {v: k for k, v in self.vocab.items()} | ||||
|         self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) | ||||
|         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | ||||
|  | ||||
|     def tokenize(self, text): | ||||
|         split_tokens = [] | ||||
|         for token in self.basic_tokenizer.tokenize(text): | ||||
|             for sub_token in self.wordpiece_tokenizer.tokenize(token): | ||||
|                 split_tokens.append(sub_token) | ||||
|  | ||||
|         return split_tokens | ||||
|  | ||||
|     def convert_tokens_to_ids(self, tokens): | ||||
|         return convert_by_vocab(self.vocab, tokens) | ||||
|  | ||||
|     def convert_ids_to_tokens(self, ids): | ||||
|         return convert_by_vocab(self.inv_vocab, ids) | ||||
|  | ||||
|     @staticmethod | ||||
|     def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): | ||||
|         """Converts a sequence of tokens (strings) into a single string.""" | ||||
|  | ||||
|         def clean_up_tokenization(out_string): | ||||
|             """Clean up a list of simple English tokenization artifacts, | ||||
|             such as spaces before punctuation and around contractions. | ||||
|             """ | ||||
|             out_string = ( | ||||
|                 out_string.replace(" .", ".") | ||||
|                     .replace(" ?", "?") | ||||
|                     .replace(" !", "!") | ||||
|                     .replace(" ,", ",") | ||||
|                     .replace(" ' ", "'") | ||||
|                     .replace(" n't", "n't") | ||||
|                     .replace(" 'm", "'m") | ||||
|                     .replace(" 's", "'s") | ||||
|                     .replace(" 've", "'ve") | ||||
|                     .replace(" 're", "'re") | ||||
|             ) | ||||
|             return out_string | ||||
|  | ||||
|         text = ' '.join(tokens).replace(' ##', '').strip() | ||||
|         if clean_up_tokenization_spaces: | ||||
|             clean_text = clean_up_tokenization(text) | ||||
|             return clean_text | ||||
|         else: | ||||
|             return text | ||||
|  | ||||
|     def vocab_size(self): | ||||
|         return len(self.vocab) | ||||
|  | ||||
|  | ||||
| class BasicTokenizer(object): | ||||
|     """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | ||||
|  | ||||
|     def __init__(self, do_lower_case=True): | ||||
|         """Constructs a BasicTokenizer. | ||||
|  | ||||
|         Args: | ||||
|           do_lower_case: Whether to lower case the input. | ||||
|         """ | ||||
|         self.do_lower_case = do_lower_case | ||||
|  | ||||
|     def tokenize(self, text): | ||||
|         """Tokenizes a piece of text.""" | ||||
|         text = convert_to_unicode(text) | ||||
|         text = self._clean_text(text) | ||||
|  | ||||
|         # This was added on November 1st, 2018 for the multilingual and Chinese | ||||
|         # models. This is also applied to the English models now, but it doesn't | ||||
|         # matter since the English models were not trained on any Chinese data | ||||
|         # and generally don't have any Chinese data in them (there are Chinese | ||||
|         # characters in the vocabulary because Wikipedia does have some Chinese | ||||
|         # words in the English Wikipedia.). | ||||
|         text = self._tokenize_chinese_chars(text) | ||||
|  | ||||
|         orig_tokens = whitespace_tokenize(text) | ||||
|         split_tokens = [] | ||||
|         for token in orig_tokens: | ||||
|             if self.do_lower_case: | ||||
|                 token = token.lower() | ||||
|                 token = self._run_strip_accents(token) | ||||
|             split_tokens.extend(self._run_split_on_punc(token)) | ||||
|  | ||||
|         output_tokens = whitespace_tokenize(" ".join(split_tokens)) | ||||
|         return output_tokens | ||||
|  | ||||
|     def _run_strip_accents(self, text): | ||||
|         """Strips accents from a piece of text.""" | ||||
|         text = unicodedata.normalize("NFD", text) | ||||
|         output = [] | ||||
|         for char in text: | ||||
|             cat = unicodedata.category(char) | ||||
|             if cat == "Mn": | ||||
|                 continue | ||||
|             output.append(char) | ||||
|         return "".join(output) | ||||
|  | ||||
|     def _run_split_on_punc(self, text): | ||||
|         """Splits punctuation on a piece of text.""" | ||||
|         chars = list(text) | ||||
|         i = 0 | ||||
|         start_new_word = True | ||||
|         output = [] | ||||
|         while i < len(chars): | ||||
|             char = chars[i] | ||||
|             if _is_punctuation(char): | ||||
|                 output.append([char]) | ||||
|                 start_new_word = True | ||||
|             else: | ||||
|                 if start_new_word: | ||||
|                     output.append([]) | ||||
|                 start_new_word = False | ||||
|                 output[-1].append(char) | ||||
|             i += 1 | ||||
|  | ||||
|         return ["".join(x) for x in output] | ||||
|  | ||||
|     def _tokenize_chinese_chars(self, text): | ||||
|         """Adds whitespace around any CJK character.""" | ||||
|         output = [] | ||||
|         for char in text: | ||||
|             cp = ord(char) | ||||
|             if self._is_chinese_char(cp): | ||||
|                 output.append(" ") | ||||
|                 output.append(char) | ||||
|                 output.append(" ") | ||||
|             else: | ||||
|                 output.append(char) | ||||
|         return "".join(output) | ||||
|  | ||||
|     def _is_chinese_char(self, cp): | ||||
|         """Checks whether CP is the codepoint of a CJK character.""" | ||||
|         # This defines a "chinese character" as anything in the CJK Unicode block: | ||||
|         #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | ||||
|         # | ||||
|         # Note that the CJK Unicode block is NOT all Japanese and Korean characters, | ||||
|         # despite its name. The modern Korean Hangul alphabet is a different block, | ||||
|         # as is Japanese Hiragana and Katakana. Those alphabets are used to write | ||||
|         # space-separated words, so they are not treated specially and handled | ||||
|         # like all of the other languages. | ||||
|         if ((cp >= 0x4E00 and cp <= 0x9FFF) or  # | ||||
|             (cp >= 0x3400 and cp <= 0x4DBF) or  # | ||||
|             (cp >= 0x20000 and cp <= 0x2A6DF) or  # | ||||
|             (cp >= 0x2A700 and cp <= 0x2B73F) or  # | ||||
|             (cp >= 0x2B740 and cp <= 0x2B81F) or  # | ||||
|             (cp >= 0x2B820 and cp <= 0x2CEAF) or | ||||
|             (cp >= 0xF900 and cp <= 0xFAFF) or  # | ||||
|                 (cp >= 0x2F800 and cp <= 0x2FA1F)):  # | ||||
|             return True | ||||
|  | ||||
|         return False | ||||
|  | ||||
|     def _clean_text(self, text): | ||||
|         """Performs invalid character removal and whitespace cleanup on text.""" | ||||
|         output = [] | ||||
|         for char in text: | ||||
|             cp = ord(char) | ||||
|             if cp == 0 or cp == 0xfffd or _is_control(char): | ||||
|                 continue | ||||
|             if _is_whitespace(char): | ||||
|                 output.append(" ") | ||||
|             else: | ||||
|                 output.append(char) | ||||
|         return "".join(output) | ||||
|  | ||||
|  | ||||
| class WordpieceTokenizer(object): | ||||
|     """Runs WordPiece tokenization.""" | ||||
|  | ||||
|     def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): | ||||
|         self.vocab = vocab | ||||
|         self.unk_token = unk_token | ||||
|         self.max_input_chars_per_word = max_input_chars_per_word | ||||
|  | ||||
|     def tokenize(self, text): | ||||
|         """Tokenizes a piece of text into its word pieces. | ||||
|  | ||||
|         This uses a greedy longest-match-first algorithm to perform tokenization | ||||
|         using the given vocabulary. | ||||
|  | ||||
|         For example: | ||||
|           input = "unaffable" | ||||
|           output = ["un", "##aff", "##able"] | ||||
|  | ||||
|         Args: | ||||
|           text: A single token or whitespace separated tokens. This should have | ||||
|             already been passed through `BasicTokenizer`. | ||||
|  | ||||
|         Returns: | ||||
|           A list of wordpiece tokens. | ||||
|         """ | ||||
|  | ||||
|         text = convert_to_unicode(text) | ||||
|  | ||||
|         output_tokens = [] | ||||
|         for token in whitespace_tokenize(text): | ||||
|             chars = list(token) | ||||
|             if len(chars) > self.max_input_chars_per_word: | ||||
|                 output_tokens.append(self.unk_token) | ||||
|                 continue | ||||
|  | ||||
|             is_bad = False | ||||
|             start = 0 | ||||
|             sub_tokens = [] | ||||
|             while start < len(chars): | ||||
|                 end = len(chars) | ||||
|                 cur_substr = None | ||||
|                 while start < end: | ||||
|                     substr = "".join(chars[start:end]) | ||||
|                     if start > 0: | ||||
|                         substr = "##" + substr | ||||
|                     if substr in self.vocab: | ||||
|                         cur_substr = substr | ||||
|                         break | ||||
|                     end -= 1 | ||||
|                 if cur_substr is None: | ||||
|                     is_bad = True | ||||
|                     break | ||||
|                 sub_tokens.append(cur_substr) | ||||
|                 start = end | ||||
|  | ||||
|             if is_bad: | ||||
|                 output_tokens.append(self.unk_token) | ||||
|             else: | ||||
|                 output_tokens.extend(sub_tokens) | ||||
|         return output_tokens | ||||
|  | ||||
|  | ||||
| def _is_whitespace(char): | ||||
|     """Checks whether `chars` is a whitespace character.""" | ||||
|     # \t, \n, and \r are technically control characters but we treat them | ||||
|     # as whitespace since they are generally considered as such. | ||||
|     if char == " " or char == "\t" or char == "\n" or char == "\r": | ||||
|         return True | ||||
|     cat = unicodedata.category(char) | ||||
|     if cat == "Zs": | ||||
|         return True | ||||
|     return False | ||||
|  | ||||
|  | ||||
| def _is_control(char): | ||||
|     """Checks whether `chars` is a control character.""" | ||||
|     # These are technically control characters but we count them as whitespace | ||||
|     # characters. | ||||
|     if char == "\t" or char == "\n" or char == "\r": | ||||
|         return False | ||||
|     cat = unicodedata.category(char) | ||||
|     if cat in ("Cc", "Cf"): | ||||
|         return True | ||||
|     return False | ||||
|  | ||||
|  | ||||
| def _is_punctuation(char): | ||||
|     """Checks whether `chars` is a punctuation character.""" | ||||
|     cp = ord(char) | ||||
|     # We treat all non-letter/number ASCII as punctuation. | ||||
|     # Characters such as "^", "$", and "`" are not in the Unicode | ||||
|     # Punctuation class but we treat them as punctuation anyways, for | ||||
|     # consistency. | ||||
|     if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or | ||||
|             (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): | ||||
|         return True | ||||
|     cat = unicodedata.category(char) | ||||
|     if cat.startswith("P"): | ||||
|         return True | ||||
|     return False | ||||
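The WordpieceTokenizer in the hunk above implements greedy longest-match-first segmentation: starting from the full remaining token it shrinks the candidate substring until it finds a vocabulary entry (continuation pieces carry a "##" prefix), and emits [UNK] if no prefix matches at all. A self-contained sketch of that behaviour with a toy in-memory vocabulary; the module name is an assumption based on the relative import ".bert_tokenization" used in the next file:

    # Toy vocabulary; a real BERT vocab.txt holds roughly 30k entries.
    from bert_tokenization import WordpieceTokenizer  # assumed module name for the hunk above

    toy_vocab = {"[UNK]": 0, "un": 1, "##aff": 2, "##able": 3, "runn": 4, "##ing": 5}
    wp = WordpieceTokenizer(vocab=toy_vocab)

    print(wp.tokenize("unaffable"))   # ['un', '##aff', '##able']  greedy longest match
    print(wp.tokenize("running"))     # ['runn', '##ing']          longest prefix 'runn' is taken first
    print(wp.tokenize("xyz"))         # ['[UNK]']                  no prefix of 'xyz' is in the vocab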
							
								
								
									
examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py (new file, 256 lines)
							| @@ -0,0 +1,256 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
| """Megatron tokenizers.""" | ||||
|  | ||||
| from abc import ABC | ||||
| from abc import abstractmethod | ||||
| from colossalai.core import global_context as gpc | ||||
| from colossalai.context import ParallelMode | ||||
|  | ||||
| from .bert_tokenization import FullTokenizer as FullBertTokenizer | ||||
|  | ||||
|  | ||||
| def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): | ||||
|     """Initialize tokenizer.""" | ||||
|     if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: | ||||
|         print('> building {} tokenizer ...'.format(tokenizer_type), | ||||
|               flush=True) | ||||
|  | ||||
|     # Select and instantiate the tokenizer. | ||||
|     if tokenizer_type == 'BertWordPieceLowerCase': | ||||
|         tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, | ||||
|                                             lower_case=True, | ||||
|                                             vocab_extra_ids=vocab_extra_ids) | ||||
|     elif tokenizer_type == 'BertWordPieceCase': | ||||
|         tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, | ||||
|                                             lower_case=False, | ||||
|                                             vocab_extra_ids=vocab_extra_ids) | ||||
|     else: | ||||
|         raise NotImplementedError('{} tokenizer is not ' | ||||
|                                   'implemented.'.format(tokenizer_type)) | ||||
|  | ||||
|     # Add vocab size. | ||||
|     padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size) | ||||
|  | ||||
|     return tokenizer, padded_vocab_size | ||||
|  | ||||
|  | ||||
| def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): | ||||
|     """Pad the vocab size so that it is divisible by the model parallel size | ||||
|     and still has a GPU-friendly size.""" | ||||
|  | ||||
|     after = orig_vocab_size | ||||
|  | ||||
|     if gpc.is_initialized(ParallelMode.TENSOR): | ||||
|         multiple = make_vocab_size_divisible_by * gpc.get_world_size(ParallelMode.TENSOR) | ||||
|     else: | ||||
|         multiple = make_vocab_size_divisible_by | ||||
|     while (after % multiple) != 0: | ||||
|         after += 1 | ||||
|     if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: | ||||
|         print(' > padded vocab (size: {}) with {} dummy tokens ' | ||||
|               '(new size: {})'.format( | ||||
|                   orig_vocab_size, after - orig_vocab_size, after), flush=True) | ||||
|     return after | ||||
|  | ||||
|  | ||||
| class AbstractTokenizer(ABC): | ||||
|     """Abstract class for tokenizer.""" | ||||
|  | ||||
|     def __init__(self, name): | ||||
|         self.name = name | ||||
|         super().__init__() | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def vocab_size(self): | ||||
|         pass | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def vocab(self): | ||||
|         """Dictionary from vocab text token to id token.""" | ||||
|         pass | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def inv_vocab(self): | ||||
|         """Dictionary from vocab id token to text token.""" | ||||
|         pass | ||||
|  | ||||
|     @abstractmethod | ||||
|     def tokenize(self, text): | ||||
|         pass | ||||
|  | ||||
|     def detokenize(self, token_ids): | ||||
|         raise NotImplementedError('detokenizer is not implemented for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def cls(self): | ||||
|         raise NotImplementedError('CLS is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def sep(self): | ||||
|         raise NotImplementedError('SEP is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def pad(self): | ||||
|         raise NotImplementedError('PAD is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def eod(self): | ||||
|         raise NotImplementedError('EOD is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def mask(self): | ||||
|         raise NotImplementedError('MASK is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|  | ||||
| class _BertWordPieceTokenizer(AbstractTokenizer): | ||||
|     """Original BERT wordpiece tokenizer.""" | ||||
|  | ||||
|     def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): | ||||
|         if lower_case: | ||||
|             name = 'BERT Lower Case' | ||||
|         else: | ||||
|             name = 'BERT Upper Case' | ||||
|         super().__init__(name) | ||||
|         self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) | ||||
|         self.cls_id = self.tokenizer.vocab['[CLS]'] | ||||
|         self.sep_id = self.tokenizer.vocab['[SEP]'] | ||||
|         self.pad_id = self.tokenizer.vocab['[PAD]'] | ||||
|         self.mask_id = self.tokenizer.vocab['[MASK]'] | ||||
|         self._additional_special_tokens = [] | ||||
|  | ||||
|         # (dsachan) Add BOS and EOS tokens | ||||
|         SPECIAL_TOKENS = {'eos_token': '[EOS]', | ||||
|                           'bos_token': '[BOS]'} | ||||
|         self._bos_token = '[BOS]' | ||||
|         self.add_token(self._bos_token) | ||||
|         self._bos_token_id = self.vocab.get(self._bos_token) | ||||
|  | ||||
|         self._eos_token = '[EOS]' | ||||
|         self.add_token(self._eos_token) | ||||
|         self._eos_token_id = self.vocab.get(self._eos_token) | ||||
|  | ||||
|         # (dsachan) Add additional special tokens | ||||
|         # These can be used as sentinel tokens in T5 model inputs | ||||
|         additional_special_tokens = [] | ||||
|         additional_special_tokens.extend( | ||||
|             ["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)]) | ||||
|         self.add_additional_special_tokens(additional_special_tokens) | ||||
|  | ||||
|     def add_token(self, token): | ||||
|         if token not in self.vocab: | ||||
|             self.inv_vocab[self.vocab_size] = token | ||||
|             # self.vocab_size comes from len(vocab) | ||||
|             # and it will increase as we add elements | ||||
|             self.vocab[token] = self.vocab_size | ||||
|  | ||||
|     def add_additional_special_tokens(self, tokens_list): | ||||
|         setattr(self, "additional_special_tokens", tokens_list) | ||||
|         for value in tokens_list: | ||||
|             self.add_token(value) | ||||
|  | ||||
|     @property | ||||
|     def vocab_size(self): | ||||
|         return self.tokenizer.vocab_size() | ||||
|  | ||||
|     @property | ||||
|     def vocab(self): | ||||
|         return self.tokenizer.vocab | ||||
|  | ||||
|     @property | ||||
|     def inv_vocab(self): | ||||
|         return self.tokenizer.inv_vocab | ||||
|  | ||||
|     def tokenize(self, text): | ||||
|         text_tokens = self.tokenizer.tokenize(text) | ||||
|         return self.tokenizer.convert_tokens_to_ids(text_tokens) | ||||
|  | ||||
|     def decode(self, ids): | ||||
|         tokens = self.tokenizer.convert_ids_to_tokens(ids) | ||||
|         return self.tokenizer.convert_tokens_to_string(tokens) | ||||
|  | ||||
|     def decode_token_ids(self, token_ids): | ||||
|         tokens = self.tokenizer.convert_ids_to_tokens(token_ids) | ||||
|         exclude_list = ['[PAD]', '[CLS]'] | ||||
|         non_pads = [t for t in tokens if t not in exclude_list] | ||||
|  | ||||
|         result = "" | ||||
|         for s in non_pads: | ||||
|             if s.startswith("##"): | ||||
|                 result += s[2:] | ||||
|             else: | ||||
|                 result += " " + s | ||||
|  | ||||
|         return result | ||||
|  | ||||
|     @property | ||||
|     def cls(self): | ||||
|         return self.cls_id | ||||
|  | ||||
|     @property | ||||
|     def sep(self): | ||||
|         return self.sep_id | ||||
|  | ||||
|     @property | ||||
|     def pad(self): | ||||
|         return self.pad_id | ||||
|  | ||||
|     @property | ||||
|     def mask(self): | ||||
|         return self.mask_id | ||||
|  | ||||
|     @property | ||||
|     def bos_token(self): | ||||
|         """ Beginning of sentence token (string). """ | ||||
|         return self._bos_token | ||||
|  | ||||
|     @property | ||||
|     def eos_token(self): | ||||
|         """ End of sentence token (string). """ | ||||
|         return self._eos_token | ||||
|  | ||||
|     @property | ||||
|     def additional_special_tokens(self): | ||||
|         """ All the additional special tokens you may want to use (list of strings).""" | ||||
|         return self._additional_special_tokens | ||||
|  | ||||
|     @property | ||||
|     def bos_token_id(self): | ||||
|         """ Id of the beginning of sentence token in the vocabulary.""" | ||||
|         return self._bos_token_id | ||||
|  | ||||
|     @property | ||||
|     def eos_token_id(self): | ||||
|         """ Id of the end of sentence token in the vocabulary.""" | ||||
|         return self._eos_token_id | ||||
|  | ||||
|     @property | ||||
|     def additional_special_tokens_ids(self): | ||||
|         """ Ids of all the additional special tokens in the vocabulary (list of integers).""" | ||||
|         return [self.vocab.get(token) for token in self._additional_special_tokens] | ||||
|  | ||||
|     @additional_special_tokens.setter | ||||
|     def additional_special_tokens(self, value): | ||||
|         self._additional_special_tokens = value | ||||
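Taken together, build_tokenizer in the hunk above selects the WordPiece tokenizer and pads its vocabulary to a multiple of make_vocab_size_divisible_by times the tensor-parallel world size (just 128 when tensor parallelism is not initialized), so the embedding table splits evenly across GPUs. An end-to-end sketch under stated assumptions: the import path mirrors the example's folder layout, the vocab file name is a placeholder, and colossalai only needs to be importable, not launched:

    from data.tokenizer.tokenizer import build_tokenizer  # assumed import path for the hunk above

    # Placeholder vocab file; the standard uncased BERT vocab has 30522 entries.
    tokenizer, padded_vocab_size = build_tokenizer(
        vocab_file="bert-large-uncased-vocab.txt",
        tokenizer_type="BertWordPieceLowerCase",
    )

    # 30522 entries + [BOS] + [EOS] = 30524, padded up to the next multiple of 128 -> 30592
    print(padded_vocab_size)

    ids = tokenizer.tokenize("colossal-ai scales transformer training")
    print(ids)                              # WordPiece ids, ready for the dataloader
    print(tokenizer.decode_token_ids(ids))  # roughly recovers the lowercased text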