Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-30 21:39:05 +00:00)
[tutorial] edited hands-on practices (#1899)

* Add hands-on examples to ColossalAI.
* Rename the hands-on examples and edit the sequence parallel example.
* Fix the wrong folder name.
* Resolve conflicts.
* Delete the readme.
102 examples/tutorial/sequence_parallel/data/__init__.py (new file)
							| @@ -0,0 +1,102 @@ | ||||
| from colossalai.context.parallel_context import ParallelContext | ||||
| from colossalai.core import global_context as gpc | ||||
| from colossalai.logging import get_dist_logger | ||||
| from colossalai.context import ParallelMode | ||||
| from .datasets.data_samplers import build_pretraining_data_loader | ||||
| from .datasets.builder import build_train_valid_test_datasets | ||||
| import torch | ||||
|  | ||||
|  | ||||
| def cyclic_iter(iter): | ||||
|     while True: | ||||
|         for x in iter: | ||||
|             yield x | ||||
|  | ||||
|  | ||||
| def build_train_valid_test_data_iterators(train_iters, | ||||
|                                           global_batch_size, | ||||
|                                           eval_interval, | ||||
|                                           eval_iters, | ||||
|                                           dataloader_type='single', | ||||
|                                           **kwargs | ||||
|                                           ): | ||||
|     (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) | ||||
|  | ||||
|     logger = get_dist_logger() | ||||
|     logger.info('> building train, validation, and test datasets ...', ranks=[0]) | ||||
|  | ||||
|     # Backward compatibility, assume fixed batch size. | ||||
|     # if iteration > 0 and consumed_train_samples == 0: | ||||
|     #     assert train_samples is None, \ | ||||
|     #         'only backward compatibility support for iteration-based training' | ||||
|     #     consumed_train_samples = iteration * global_batch_size | ||||
|     # if iteration > 0 and consumed_valid_samples == 0: | ||||
|     #     if train_samples is None: | ||||
|     #         consumed_valid_samples = (iteration // eval_interval) * \ | ||||
|     #             eval_iters * global_batch_size | ||||
|  | ||||
|     # Data loader only on rank 0 of each model parallel group. | ||||
|     if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: | ||||
|  | ||||
|         # Number of train/valid/test samples. | ||||
|         train_samples = train_iters * global_batch_size | ||||
|         eval_iters_ = (train_iters // eval_interval + 1) * eval_iters | ||||
|         test_iters = eval_iters | ||||
|         train_val_test_num_samples = [train_samples, | ||||
|                                       eval_iters_ * global_batch_size, | ||||
|                                       test_iters * global_batch_size] | ||||
|         logger.info(' > datasets target sizes (minimum size):', ranks=[0]) | ||||
|         logger.info('    train:      {}'.format(train_val_test_num_samples[0]), ranks=[0]) | ||||
|         logger.info('    validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) | ||||
|         logger.info('    test:       {}'.format(train_val_test_num_samples[2]), ranks=[0]) | ||||
|  | ||||
|         # Build the datasets. | ||||
|         train_ds, valid_ds, test_ds = build_train_valid_test_datasets( | ||||
|             train_valid_test_num_samples=train_val_test_num_samples, **kwargs) | ||||
|  | ||||
|         # Build dataloaders. | ||||
|         dp_size = gpc.get_world_size(ParallelMode.DATA) | ||||
|         train_dataloader = build_pretraining_data_loader( | ||||
|             train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) | ||||
|         valid_dataloader = build_pretraining_data_loader( | ||||
|             valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) | ||||
|         test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size) | ||||
|  | ||||
|         # Flags to know if we need to do training/validation/testing. | ||||
|         do_train = train_dataloader is not None and train_iters > 0 | ||||
|         do_valid = valid_dataloader is not None and eval_iters > 0 | ||||
|         do_test = test_dataloader is not None and eval_iters > 0 | ||||
|         # Need to broadcast the do_train/do_valid/do_test flags. | ||||
|         flags = torch.cuda.LongTensor( | ||||
|             [int(do_train), int(do_valid), int(do_test)]) | ||||
|     else: | ||||
|         flags = torch.cuda.LongTensor([0, 0, 0]) | ||||
|  | ||||
|     # Broadcast the flags. | ||||
|     torch.distributed.broadcast(flags, | ||||
|                                 gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], | ||||
|                                 group=gpc.get_group(ParallelMode.TENSOR)) | ||||
|  | ||||
|     # Build iterators. | ||||
|     dl_type = dataloader_type | ||||
|     assert dl_type in ['single', 'cyclic'] | ||||
|  | ||||
|     if train_dataloader is not None: | ||||
|         train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ | ||||
|             else iter(cyclic_iter(train_dataloader)) | ||||
|     else: | ||||
|         train_data_iterator = None | ||||
|  | ||||
|     if valid_dataloader is not None: | ||||
|         valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ | ||||
|             else iter(cyclic_iter(valid_dataloader)) | ||||
|     else: | ||||
|         valid_data_iterator = None | ||||
|  | ||||
|     if test_dataloader is not None: | ||||
|         test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ | ||||
|             else iter(cyclic_iter(test_dataloader)) | ||||
|     else: | ||||
|         test_data_iterator = None | ||||
|  | ||||
|     return train_data_iterator, valid_data_iterator, test_data_iterator | ||||
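As a hedged orientation for readers of this diff, the sketch below shows how `build_train_valid_test_data_iterators` might be called from a training script. It assumes the ColossalAI context has already been launched (e.g. via `colossalai.launch_from_torch`), that the data has been preprocessed into Megatron's indexed format, and that the example runs from the `sequence_parallel` directory so the package imports as `data`; all argument values are illustrative placeholders rather than tutorial defaults.

```python
from data import build_train_valid_test_data_iterators  # assumed import path

train_it, valid_it, test_it = build_train_valid_test_data_iterators(
    train_iters=1000,                 # illustrative values
    global_batch_size=32,             # divided by the data-parallel world size per loader
    eval_interval=100,
    eval_iters=10,
    dataloader_type='single',         # or 'cyclic' to wrap around the dataset
    # remaining kwargs are forwarded to build_train_valid_test_datasets:
    data_prefix=['/path/to/my-bert_text_sentence'],
    data_impl='mmap',
    splits_string='949,50,1',
    max_seq_length=128,
    masked_lm_prob=0.15,
    short_seq_prob=0.1,
    seed=1234,
    skip_warmup=True,
    binary_head=True,
)
batch = next(train_it)                # one collated micro-batch per step
```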
							
								
								
									
165 examples/tutorial/sequence_parallel/data/bert_helper.py (new file)
							| @@ -0,0 +1,165 @@ | ||||
| from colossalai.core import global_context as gpc | ||||
| from colossalai.context import ParallelMode | ||||
| import torch | ||||
|  | ||||
| _MAX_DATA_DIM = 5 | ||||
|  | ||||
|  | ||||
| def _build_key_size_numel_dictionaries(keys, data): | ||||
|     """Build the size on rank 0 and broadcast.""" | ||||
|     max_dim = _MAX_DATA_DIM | ||||
|     sizes = [0 for _ in range(max_dim) for _ in keys] | ||||
|  | ||||
|     # Pack the sizes on rank zero. | ||||
|     if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: | ||||
|         offset = 0 | ||||
|         for key in keys: | ||||
|             assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' | ||||
|             size = data[key].size() | ||||
|             for i, s in enumerate(size): | ||||
|                 sizes[i + offset] = s | ||||
|             offset += max_dim | ||||
|  | ||||
|     # Move to GPU and broadcast. | ||||
|     sizes_cuda = torch.cuda.LongTensor(sizes) | ||||
|     torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], | ||||
|                                 group=gpc.get_group(ParallelMode.TENSOR)) | ||||
|  | ||||
|     # Move back to cpu and unpack. | ||||
|     sizes_cpu = sizes_cuda.cpu() | ||||
|     key_size = {} | ||||
|     key_numel = {} | ||||
|     total_numel = 0 | ||||
|     offset = 0 | ||||
|     for key in keys: | ||||
|         i = 0 | ||||
|         size = [] | ||||
|         numel = 1 | ||||
|         while sizes_cpu[offset + i] > 0: | ||||
|             this_size = sizes_cpu[offset + i] | ||||
|             size.append(this_size) | ||||
|             numel *= this_size | ||||
|             i += 1 | ||||
|         key_size[key] = size | ||||
|         key_numel[key] = numel | ||||
|         total_numel += numel | ||||
|         offset += max_dim | ||||
|  | ||||
|     return key_size, key_numel, total_numel | ||||
|  | ||||
|  | ||||
| def broadcast_data(keys, data, datatype): | ||||
|     """Broadcast data from rank zero of each model parallel group to the | ||||
|     members of the same model parallel group. | ||||
|  | ||||
|     Arguments: | ||||
|         keys: list of keys in the data dictionary to be broadcasted | ||||
|         data: data dictionary of string keys and cpu tensor values. | ||||
|         datatype: torch data type of all tensors in data associated | ||||
|                   with keys. | ||||
|     """ | ||||
|     # Build (key, size) and (key, number of elements) dictionaries along | ||||
|     # with the total number of elements on all ranks. | ||||
|     key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, | ||||
|                                                                           data) | ||||
|  | ||||
|     # Pack on rank zero. | ||||
|     if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: | ||||
|         # Check that all keys have the same data type. | ||||
|         # Flatten the data associated with the keys | ||||
|         flatten_data = torch.cat( | ||||
|             [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() | ||||
|     else: | ||||
|         flatten_data = torch.empty(total_numel, | ||||
|                                    device=torch.cuda.current_device(), | ||||
|                                    dtype=datatype) | ||||
|  | ||||
|     # Broadcast | ||||
|     torch.distributed.broadcast(flatten_data, | ||||
|                                 gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], | ||||
|                                 group=gpc.get_group(ParallelMode.TENSOR)) | ||||
|  | ||||
|     # Unpack | ||||
|     output = {} | ||||
|     offset = 0 | ||||
|     for key in keys: | ||||
|         size = key_size[key] | ||||
|         numel = key_numel[key] | ||||
|         output[key] = flatten_data.narrow(0, offset, numel).view(size) | ||||
|         offset += numel | ||||
|  | ||||
|     return output | ||||
|  | ||||
|  | ||||
| def get_batch(data_iterator): | ||||
|     """Build the batch.""" | ||||
|  | ||||
|     # Items and their type. | ||||
|     keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] | ||||
|     datatype = torch.int64 | ||||
|  | ||||
|     # Broadcast data. | ||||
|     if data_iterator is not None: | ||||
|         data = next(data_iterator) | ||||
|     else: | ||||
|         data = None | ||||
|     data_b = broadcast_data(keys, data, datatype) | ||||
|  | ||||
|     # Unpack. | ||||
|     tokens = data_b['text'].long() | ||||
|     types = data_b['types'].long() | ||||
|     sentence_order = data_b['is_random'].long() | ||||
|     loss_mask = data_b['loss_mask'].float() | ||||
|     lm_labels = data_b['labels'].long() | ||||
|     padding_mask = data_b['padding_mask'].long() | ||||
|  | ||||
|     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask | ||||
|  | ||||
|  | ||||
| def get_batch_for_sequence_parallel(data_iterator): | ||||
|     """Build the batch.""" | ||||
|  | ||||
|     # Items and their type. | ||||
|     keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] | ||||
|     datatype = torch.int64 | ||||
|  | ||||
|     # Broadcast data. | ||||
|     if data_iterator is not None: | ||||
|         data = next(data_iterator) | ||||
|     else: | ||||
|         data = None | ||||
|  | ||||
|     # Broadcast and unpack the data into a dictionary of tensors. | ||||
|     data_b = broadcast_data(keys, data, datatype) | ||||
|  | ||||
|     # Get the tensor parallel local rank and this rank's sub-sequence range. | ||||
|     global_rank = torch.distributed.get_rank() | ||||
|     local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) | ||||
|     local_rank = global_rank % local_world_size | ||||
|     seq_length = data_b['text'].size(1) | ||||
|     sub_seq_length = seq_length // local_world_size | ||||
|     sub_seq_start = local_rank * sub_seq_length | ||||
|     sub_seq_end = (local_rank+1) * sub_seq_length | ||||
|     # Unpack; tensors with a sequence dimension are sliced to this rank's sub-sequence. | ||||
|     tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() | ||||
|     types = data_b['types'][:, sub_seq_start:sub_seq_end].long() | ||||
|     sentence_order = data_b['is_random'].long() | ||||
|     loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float() | ||||
|     lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long() | ||||
|     padding_mask = data_b['padding_mask'].long() | ||||
|  | ||||
|     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask | ||||
|  | ||||
|  | ||||
| class SequenceParallelDataIterator: | ||||
|  | ||||
|     def __init__(self, data_iter): | ||||
|         self.data_iter = data_iter | ||||
|      | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return self.data_iter | ||||
|  | ||||
|     def __next__(self): | ||||
|         return get_batch_for_sequence_parallel(self.data_iter) | ||||
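`get_batch_for_sequence_parallel` gives every tensor-parallel rank a contiguous slice of the sequence dimension, and `SequenceParallelDataIterator` simply applies that slicing on each `next()` call over a wrapped dataloader iterator. The pure-Python sketch below only illustrates the slice arithmetic for an assumed sequence length and world size; it does not touch the distributed runtime.

```python
# Illustrative slice arithmetic for sequence parallelism (assumed sizes).
seq_length = 512
local_world_size = 4                       # tensor/sequence parallel world size
sub_seq_length = seq_length // local_world_size

for local_rank in range(local_world_size):
    sub_seq_start = local_rank * sub_seq_length
    sub_seq_end = (local_rank + 1) * sub_seq_length
    # Each rank would keep tokens[:, sub_seq_start:sub_seq_end].
    print(f"rank {local_rank}: columns [{sub_seq_start}, {sub_seq_end})")
```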
| @@ -0,0 +1,9 @@ | ||||
| CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color | ||||
| CPPFLAGS += $(shell python3 -m pybind11 --includes) | ||||
| LIBNAME = helpers | ||||
| LIBEXT = $(shell python3-config --extension-suffix) | ||||
|  | ||||
| default: $(LIBNAME)$(LIBEXT) | ||||
|  | ||||
| %$(LIBEXT): %.cpp | ||||
| 	$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ | ||||
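The Makefile above compiles the pybind11 `helpers` extension that the dataset code imports as `from . import helpers`. Below is a minimal sketch of kicking off that build from Python, along the lines of the `compile_helper()` utility in `dataset_utils.py`; the directory path is an assumption and should point at wherever the Makefile and the helpers C++ source actually live.

```python
import os
import subprocess

# Assumed location of the Makefile; adjust to the actual datasets directory.
helpers_dir = os.path.join('examples', 'tutorial', 'sequence_parallel', 'data', 'datasets')
subprocess.run(['make', '-C', helpers_dir], check=True)
```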
| @@ -0,0 +1 @@ | ||||
| from . import indexed_dataset | ||||
| @@ -0,0 +1,225 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """BERT Style dataset.""" | ||||
|  | ||||
| from colossalai.logging import get_dist_logger | ||||
| import numpy as np | ||||
| import torch | ||||
| from torch.utils.data import Dataset | ||||
|  | ||||
| from ..tokenizer import get_tokenizer | ||||
| from .dataset_utils import (get_a_and_b_segments, truncate_segments, create_tokens_and_tokentypes, | ||||
|                             create_masked_lm_predictions, pad_and_convert_to_numpy) | ||||
| from colossalai.core import global_context as gpc | ||||
| from colossalai.context import ParallelMode | ||||
| import time | ||||
| import os | ||||
| from . import helpers | ||||
|  | ||||
|  | ||||
| class BertDataset(Dataset): | ||||
|  | ||||
|     def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length, | ||||
|                  short_seq_prob, seed, binary_head): | ||||
|  | ||||
|         # Params to store. | ||||
|         self.name = name | ||||
|         self.seed = seed | ||||
|         self.masked_lm_prob = masked_lm_prob | ||||
|         self.max_seq_length = max_seq_length | ||||
|         self.binary_head = binary_head | ||||
|  | ||||
|         # Dataset. | ||||
|         self.indexed_dataset = indexed_dataset | ||||
|  | ||||
|         # Build the samples mapping. | ||||
|         self.samples_mapping = get_samples_mapping_( | ||||
|             self.indexed_dataset, | ||||
|             data_prefix, | ||||
|             num_epochs, | ||||
|             max_num_samples, | ||||
|             self.max_seq_length - 3,    # account for the added [CLS] and two [SEP] tokens | ||||
|             short_seq_prob, | ||||
|             self.seed, | ||||
|             self.name, | ||||
|             self.binary_head) | ||||
|  | ||||
|         # Vocab stuff. | ||||
|         tokenizer = get_tokenizer() | ||||
|         self.vocab_id_list = list(tokenizer.inv_vocab.keys()) | ||||
|         self.vocab_id_to_token_dict = tokenizer.inv_vocab | ||||
|         self.cls_id = tokenizer.cls | ||||
|         self.sep_id = tokenizer.sep | ||||
|         self.mask_id = tokenizer.mask | ||||
|         self.pad_id = tokenizer.pad | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.samples_mapping.shape[0] | ||||
|  | ||||
|     def __getitem__(self, idx): | ||||
|         start_idx, end_idx, seq_length = self.samples_mapping[idx] | ||||
|         sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)] | ||||
|         # Note that this rng state should be numpy and not python since | ||||
|         # python randint is inclusive whereas the numpy one is exclusive. | ||||
|         # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1 | ||||
|         np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32)) | ||||
|         return build_training_sample( | ||||
|             sample, | ||||
|             seq_length, | ||||
|             self.max_seq_length,    # needed for padding | ||||
|             self.vocab_id_list, | ||||
|             self.vocab_id_to_token_dict, | ||||
|             self.cls_id, | ||||
|             self.sep_id, | ||||
|             self.mask_id, | ||||
|             self.pad_id, | ||||
|             self.masked_lm_prob, | ||||
|             np_rng, | ||||
|             self.binary_head) | ||||
|  | ||||
|  | ||||
| def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, | ||||
|                          seed, name, binary_head): | ||||
|     logger = get_dist_logger() | ||||
|     if not num_epochs: | ||||
|         if not max_num_samples: | ||||
|             raise ValueError("Need to specify either max_num_samples " | ||||
|                              "or num_epochs") | ||||
|         num_epochs = np.iinfo(np.int32).max - 1 | ||||
|     if not max_num_samples: | ||||
|         max_num_samples = np.iinfo(np.int64).max - 1 | ||||
|  | ||||
|     # Filename of the index mapping | ||||
|     indexmap_filename = data_prefix | ||||
|     indexmap_filename += '_{}_indexmap'.format(name) | ||||
|     if num_epochs != (np.iinfo(np.int32).max - 1): | ||||
|         indexmap_filename += '_{}ep'.format(num_epochs) | ||||
|     if max_num_samples != (np.iinfo(np.int64).max - 1): | ||||
|         indexmap_filename += '_{}mns'.format(max_num_samples) | ||||
|     indexmap_filename += '_{}msl'.format(max_seq_length) | ||||
|     indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) | ||||
|     indexmap_filename += '_{}s'.format(seed) | ||||
|     indexmap_filename += '.npy' | ||||
|  | ||||
|     # Build the indexed mapping if not exist. | ||||
|     if torch.distributed.get_rank() == 0 and \ | ||||
|        not os.path.isfile(indexmap_filename): | ||||
|         print(' > WARNING: could not find index map file {}, building ' | ||||
|               'the indices on rank 0 ...'.format(indexmap_filename)) | ||||
|  | ||||
|         # Make sure the types match the helpers input types. | ||||
|         assert indexed_dataset.doc_idx.dtype == np.int64 | ||||
|         assert indexed_dataset.sizes.dtype == np.int32 | ||||
|  | ||||
|         # Build samples mapping | ||||
|         verbose = torch.distributed.get_rank() == 0 | ||||
|         start_time = time.time() | ||||
|         logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0]) | ||||
|         # First compile and then import. | ||||
|         samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs, | ||||
|                                                 max_num_samples, max_seq_length, short_seq_prob, seed, verbose, | ||||
|                                                 2 if binary_head else 1) | ||||
|         logger.info('\n > done building samples index mapping', ranks=[0]) | ||||
|         np.save(indexmap_filename, samples_mapping, allow_pickle=True) | ||||
|         logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0]) | ||||
|         # Make sure all the ranks have built the mapping | ||||
|         logger.info('\n > elapsed time to build and save samples mapping ' | ||||
|                     '(seconds): {:4f}'.format(time.time() - start_time), | ||||
|                     ranks=[0]) | ||||
|     # This should be a barrier but nccl barrier assumes | ||||
|     # device_index=rank which is not the case for model | ||||
|     # parallel case | ||||
|     counts = torch.cuda.LongTensor([1]) | ||||
|     torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA)) | ||||
|     if gpc.is_initialized(ParallelMode.PIPELINE): | ||||
|         torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE)) | ||||
|     assert counts[0].item() == (torch.distributed.get_world_size() // | ||||
|                                 torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))) | ||||
|  | ||||
|     # Load indexed dataset. | ||||
|     start_time = time.time() | ||||
|     samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') | ||||
|     logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) + | ||||
|                 '\n    loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) + | ||||
|                 '\n    total number of samples: {}'.format(samples_mapping.shape[0]), | ||||
|                 ranks=[0]) | ||||
|  | ||||
|     return samples_mapping | ||||
|  | ||||
|  | ||||
| def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id, | ||||
|                           sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head): | ||||
|     """Build training sample. | ||||
|  | ||||
|     Arguments: | ||||
|         sample: A list of sentences in which each sentence is a list of token ids. | ||||
|         target_seq_length: Desired sequence length. | ||||
|         max_seq_length: Maximum length of the sequence. All values are padded to | ||||
|             this length. | ||||
|         vocab_id_list: List of vocabulary ids. Used to pick a random id. | ||||
|         vocab_id_to_token_dict: A dictionary from vocab ids to text tokens. | ||||
|         cls_id: Start of example id. | ||||
|         sep_id: Separator id. | ||||
|         mask_id: Mask token id. | ||||
|         pad_id: Padding token id. | ||||
|         masked_lm_prob: Probability to mask tokens. | ||||
|         np_rng: Random number generator. Note that this rng state should be | ||||
|               numpy and not python since python randint is inclusive for | ||||
|               the upper bound whereas the numpy one is exclusive. | ||||
|     """ | ||||
|  | ||||
|     if binary_head: | ||||
|         # We assume that we have at least two sentences in the sample | ||||
|         assert len(sample) > 1 | ||||
|     assert target_seq_length <= max_seq_length | ||||
|  | ||||
|     # Divide sample into two segments (A and B). | ||||
|     if binary_head: | ||||
|         tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng) | ||||
|     else: | ||||
|         tokens_a = [] | ||||
|         for j in range(len(sample)): | ||||
|             tokens_a.extend(sample[j]) | ||||
|         tokens_b = [] | ||||
|         is_next_random = False | ||||
|  | ||||
|     # Truncate to `target_seq_length`. | ||||
|     max_num_tokens = target_seq_length | ||||
|     truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng) | ||||
|  | ||||
|     # Build tokens and tokentypes. | ||||
|     tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id) | ||||
|  | ||||
|     # Masking. | ||||
|     max_predictions_per_seq = masked_lm_prob * max_num_tokens | ||||
|     (tokens, masked_positions, masked_labels, | ||||
|      _) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id, | ||||
|                                        mask_id, max_predictions_per_seq, np_rng) | ||||
|  | ||||
|     # Padding. | ||||
|     tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ | ||||
|         = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, | ||||
|                                    masked_labels, pad_id, max_seq_length) | ||||
|  | ||||
|     train_sample = { | ||||
|         'text': tokens_np, | ||||
|         'types': tokentypes_np, | ||||
|         'labels': labels_np, | ||||
|         'is_random': int(is_next_random), | ||||
|         'loss_mask': loss_mask_np, | ||||
|         'padding_mask': padding_mask_np, | ||||
|         'truncated': int(truncated) | ||||
|     } | ||||
|     return train_sample | ||||
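For a quick sanity check of what `BertDataset.__getitem__` produces, the hedged snippet below inspects one sample; `train_ds` is a hypothetical handle to a dataset built with `build_train_valid_test_datasets` (see `builder.py` further down).

```python
# Hedged sketch: inspect a single training sample.
sample = train_ds[0]
for key, value in sample.items():
    print(key, type(value).__name__, getattr(value, 'shape', ''))
# 'text', 'types', 'labels', 'loss_mask' and 'padding_mask' are numpy arrays of
# length max_seq_length; 'is_random' and 'truncated' are plain ints.
```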
| @@ -0,0 +1,62 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
| """Blendable dataset.""" | ||||
|  | ||||
| import time | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class BlendableDataset(torch.utils.data.Dataset): | ||||
|  | ||||
|     def __init__(self, datasets, weights): | ||||
|  | ||||
|         self.datasets = datasets | ||||
|         num_datasets = len(datasets) | ||||
|         assert num_datasets == len(weights) | ||||
|  | ||||
|         self.size = 0 | ||||
|         for dataset in self.datasets: | ||||
|             self.size += len(dataset) | ||||
|  | ||||
|         # Normalize weights. | ||||
|         weights = np.array(weights, dtype=np.float64) | ||||
|         sum_weights = np.sum(weights) | ||||
|         assert sum_weights > 0.0 | ||||
|         weights /= sum_weights | ||||
|  | ||||
|         # Build indices. | ||||
|         start_time = time.time() | ||||
|         assert num_datasets < 255 | ||||
|         self.dataset_index = np.zeros(self.size, dtype=np.uint8) | ||||
|         self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) | ||||
|  | ||||
|         from . import helpers | ||||
|         helpers.build_blending_indices(self.dataset_index, | ||||
|                                        self.dataset_sample_index, | ||||
|                                        weights, num_datasets, self.size, | ||||
|                                        torch.distributed.get_rank() == 0) | ||||
|         print('> elapsed time for building blendable dataset indices: ' | ||||
|               '{:.2f} (sec)'.format(time.time() - start_time)) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.size | ||||
|  | ||||
|     def __getitem__(self, idx): | ||||
|         dataset_idx = self.dataset_index[idx] | ||||
|         sample_idx = self.dataset_sample_index[idx] | ||||
|         return self.datasets[dataset_idx][sample_idx] | ||||
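`BlendableDataset` delegates the index assignment to the compiled `helpers.build_blending_indices` routine, which picks, for every global index, the dataset whose realized share most lags its target weight. The snippet below is a pure-Python illustration of that idea under those assumptions, not the compiled helper itself.

```python
import numpy as np

def blending_indices_sketch(weights, size):
    """Illustrative weighted interleaving: at each step, draw from the dataset
    whose sample count lags its target share the most."""
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()
    counts = np.zeros(len(weights), dtype=np.int64)
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    for i in range(size):
        errors = weights * (i + 1) - counts   # target share minus realized count
        choice = int(np.argmax(errors))
        dataset_index[i] = choice
        dataset_sample_index[i] = counts[choice]
        counts[choice] += 1
    return dataset_index, dataset_sample_index

idx, sample_idx = blending_indices_sketch([0.7, 0.3], 10)
print(idx.tolist())   # 7 indices point at dataset 0 and 3 at dataset 1
```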
							
								
								
									
152 examples/tutorial/sequence_parallel/data/datasets/builder.py (new file)
							| @@ -0,0 +1,152 @@ | ||||
| from .blendable_dataset import BlendableDataset | ||||
| from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ | ||||
| from .bert_dataset import BertDataset | ||||
| from colossalai.logging import get_dist_logger | ||||
|  | ||||
| DSET_TYPE_BERT = 'standard_bert' | ||||
| DSET_TYPE_ICT = 'ict' | ||||
| DSET_TYPE_T5 = 't5' | ||||
|  | ||||
| DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] | ||||
|  | ||||
|  | ||||
| def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                                      train_valid_test_num_samples, | ||||
|                                      max_seq_length, masked_lm_prob, | ||||
|                                      short_seq_prob, seed, skip_warmup, | ||||
|                                      binary_head, | ||||
|                                      dataset_type='standard_bert'): | ||||
|  | ||||
|     if dataset_type not in DSET_TYPES: | ||||
|         raise ValueError("Invalid dataset_type: ", dataset_type) | ||||
|  | ||||
|     # Indexed dataset. | ||||
|     indexed_dataset = get_indexed_dataset_(data_prefix, | ||||
|                                            data_impl, | ||||
|                                            skip_warmup) | ||||
|  | ||||
|     # Get start and end indices of train/valid/train into doc-idx | ||||
|     # Note that doc-idx is designed to be num-docs + 1 so we can | ||||
|     # easily iterate over it. | ||||
|     total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 | ||||
|     splits = get_train_valid_test_split_(splits_string, total_num_of_documents) | ||||
|  | ||||
|     logger = get_dist_logger() | ||||
|  | ||||
|     # Print stats about the splits. | ||||
|     logger.info('\n > dataset split:', ranks=[0]) | ||||
|  | ||||
|     def print_split_stats(name, index): | ||||
|         start_index = indexed_dataset.doc_idx[splits[index]] | ||||
|         end_index = indexed_dataset.doc_idx[splits[index + 1]] | ||||
|         logger.info('\n    {}:'.format(name) + | ||||
|                     '\n     document indices in [{}, {}) total of {} documents'.format( | ||||
|                         splits[index], splits[index + 1], | ||||
|                         splits[index + 1] - splits[index]) + | ||||
|                     '\n     sentence indices in [{}, {}) total of {} sentences'.format( | ||||
|                         start_index, end_index, | ||||
|                         end_index - start_index), | ||||
|                     ranks=[0]) | ||||
|     print_split_stats('train', 0) | ||||
|     print_split_stats('validation', 1) | ||||
|     print_split_stats('test', 2) | ||||
|  | ||||
|     def build_dataset(index, name): | ||||
|         dataset = None | ||||
|         if splits[index + 1] > splits[index]: | ||||
|             # Get the pointer to the original doc-idx so we can set it later. | ||||
|             doc_idx_ptr = indexed_dataset.get_doc_idx() | ||||
|             # Slice the doc-idx | ||||
|             start_index = splits[index] | ||||
|             # Add +1 so we can index into the dataset to get the upper bound. | ||||
|             end_index = splits[index + 1] + 1 | ||||
|             # New doc_idx view. | ||||
|             indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) | ||||
|             # Build the dataset accordingly. | ||||
|             kwargs = dict( | ||||
|                 name=name, | ||||
|                 data_prefix=data_prefix, | ||||
|                 num_epochs=None, | ||||
|                 max_num_samples=train_valid_test_num_samples[index], | ||||
|                 max_seq_length=max_seq_length, | ||||
|                 seed=seed, | ||||
|             ) | ||||
|  | ||||
|             if dataset_type != DSET_TYPE_BERT: | ||||
|                 raise NotImplementedError("Only BERT dataset is supported") | ||||
|             else: | ||||
|                 dataset = BertDataset( | ||||
|                     indexed_dataset=indexed_dataset, | ||||
|                     masked_lm_prob=masked_lm_prob, | ||||
|                     short_seq_prob=short_seq_prob, | ||||
|                     binary_head=binary_head, | ||||
|                     **kwargs | ||||
|                 ) | ||||
|  | ||||
|             # Set the original pointer so dataset remains the main dataset. | ||||
|             indexed_dataset.set_doc_idx(doc_idx_ptr) | ||||
|             # Checks. | ||||
|             assert indexed_dataset.doc_idx[0] == 0 | ||||
|             assert indexed_dataset.doc_idx.shape[0] == \ | ||||
|                 (total_num_of_documents + 1) | ||||
|         return dataset | ||||
|  | ||||
|     train_dataset = build_dataset(0, 'train') | ||||
|     valid_dataset = build_dataset(1, 'valid') | ||||
|     test_dataset = build_dataset(2, 'test') | ||||
|  | ||||
|     return (train_dataset, valid_dataset, test_dataset) | ||||
|  | ||||
|  | ||||
| def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                                     train_valid_test_num_samples, | ||||
|                                     max_seq_length, masked_lm_prob, | ||||
|                                     short_seq_prob, seed, skip_warmup, | ||||
|                                     binary_head, | ||||
|                                     dataset_type='standard_bert'): | ||||
|  | ||||
|     if len(data_prefix) == 1: | ||||
|         return _build_train_valid_test_datasets(data_prefix[0], | ||||
|                                                 data_impl, splits_string, | ||||
|                                                 train_valid_test_num_samples, | ||||
|                                                 max_seq_length, masked_lm_prob, | ||||
|                                                 short_seq_prob, seed, | ||||
|                                                 skip_warmup, | ||||
|                                                 binary_head, | ||||
|                                                 dataset_type=dataset_type) | ||||
|     # Blending dataset. | ||||
|     # Parse the values. | ||||
|     output = get_datasets_weights_and_num_samples(data_prefix, | ||||
|                                                   train_valid_test_num_samples) | ||||
|     prefixes, weights, datasets_train_valid_test_num_samples = output | ||||
|  | ||||
|     # Build individual datasets. | ||||
|     train_datasets = [] | ||||
|     valid_datasets = [] | ||||
|     test_datasets = [] | ||||
|     for i in range(len(prefixes)): | ||||
|         train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( | ||||
|             prefixes[i], data_impl, splits_string, | ||||
|             datasets_train_valid_test_num_samples[i], | ||||
|             max_seq_length, masked_lm_prob, short_seq_prob, | ||||
|             seed, skip_warmup, binary_head, dataset_type=dataset_type) | ||||
|         if train_ds: | ||||
|             train_datasets.append(train_ds) | ||||
|         if valid_ds: | ||||
|             valid_datasets.append(valid_ds) | ||||
|         if test_ds: | ||||
|             test_datasets.append(test_ds) | ||||
|  | ||||
|     # Blend. | ||||
|     blending_train_dataset = None | ||||
|     if train_datasets: | ||||
|         blending_train_dataset = BlendableDataset(train_datasets, weights) | ||||
|     blending_valid_dataset = None | ||||
|     if valid_datasets: | ||||
|         blending_valid_dataset = BlendableDataset(valid_datasets, weights) | ||||
|     blending_test_dataset = None | ||||
|     if test_datasets: | ||||
|         blending_test_dataset = BlendableDataset(test_datasets, weights) | ||||
|  | ||||
|     return (blending_train_dataset, blending_valid_dataset, | ||||
|             blending_test_dataset) | ||||
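A hedged example of calling `build_train_valid_test_datasets` with two weighted corpora, which exercises the blending branch above. The paths and hyperparameters are placeholders; real data must already exist in the indexed format and the pybind11 helpers must be compiled first.

```python
from data.datasets.builder import build_train_valid_test_datasets  # assumed import path

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix=['0.7', '/data/corpusA_text_sentence',
                 '0.3', '/data/corpusB_text_sentence'],
    data_impl='mmap',
    splits_string='949,50,1',                 # train/valid/test split of documents
    train_valid_test_num_samples=[1000, 100, 10],
    max_seq_length=128,
    masked_lm_prob=0.15,
    short_seq_prob=0.1,
    seed=1234,
    skip_warmup=True,
    binary_head=True,
)
```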
| @@ -0,0 +1,153 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """Dataloaders.""" | ||||
|  | ||||
| import torch | ||||
| import random | ||||
| from colossalai.core import global_context as gpc | ||||
| from colossalai.context import ParallelMode | ||||
|  | ||||
|  | ||||
| def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): | ||||
|     """Build dataloader given an input dataset.""" | ||||
|  | ||||
|     if dataset is None: | ||||
|         return None | ||||
|  | ||||
|     # Megatron sampler | ||||
|     if dataloader_type == 'single': | ||||
|         batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset), | ||||
|                                                    consumed_samples=consumed_samples, | ||||
|                                                    micro_batch_size=micro_batch_size, | ||||
|                                                    data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), | ||||
|                                                    data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) | ||||
|     elif dataloader_type == 'cyclic': | ||||
|         batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset), | ||||
|                                                          consumed_samples=consumed_samples, | ||||
|                                                          micro_batch_size=micro_batch_size, | ||||
|                                                          data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), | ||||
|                                                          data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) | ||||
|     else: | ||||
|         raise Exception('{} dataloader type is not supported.'.format(dataloader_type)) | ||||
|  | ||||
|     # Torch dataloader. | ||||
|     return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True) | ||||
|  | ||||
|  | ||||
| class MegatronPretrainingSampler: | ||||
|  | ||||
|     def __init__(self, | ||||
|                  total_samples, | ||||
|                  consumed_samples, | ||||
|                  micro_batch_size, | ||||
|                  data_parallel_rank, | ||||
|                  data_parallel_size, | ||||
|                  drop_last=True): | ||||
|         # Keep a copy of input params for later use. | ||||
|         self.total_samples = total_samples | ||||
|         self.consumed_samples = consumed_samples | ||||
|         self.micro_batch_size = micro_batch_size | ||||
|         self.data_parallel_rank = data_parallel_rank | ||||
|         self.micro_batch_times_data_parallel_size = \ | ||||
|             self.micro_batch_size * data_parallel_size | ||||
|         self.drop_last = drop_last | ||||
|  | ||||
|         # Sanity checks. | ||||
|         assert self.total_samples > 0, \ | ||||
|             'no sample to consume: {}'.format(self.total_samples) | ||||
|         assert self.consumed_samples < self.total_samples, \ | ||||
|             'no samples left to consume: {}, {}'.format(self.consumed_samples, | ||||
|                                                         self.total_samples) | ||||
|         assert self.micro_batch_size > 0 | ||||
|         assert data_parallel_size > 0 | ||||
|         assert self.data_parallel_rank < data_parallel_size, \ | ||||
|             'data_parallel_rank should be smaller than data size: {}, ' \ | ||||
|             '{}'.format(self.data_parallel_rank, data_parallel_size) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.total_samples | ||||
|  | ||||
|     def get_start_end_idx(self): | ||||
|         start_idx = self.data_parallel_rank * self.micro_batch_size | ||||
|         end_idx = start_idx + self.micro_batch_size | ||||
|         return start_idx, end_idx | ||||
|  | ||||
|     def __iter__(self): | ||||
|         batch = [] | ||||
|         # The last incomplete batch is dropped unless drop_last is set to False. | ||||
|         for idx in range(self.consumed_samples, self.total_samples): | ||||
|             batch.append(idx) | ||||
|             if len(batch) == self.micro_batch_times_data_parallel_size: | ||||
|                 start_idx, end_idx = self.get_start_end_idx() | ||||
|                 yield batch[start_idx:end_idx] | ||||
|                 batch = [] | ||||
|  | ||||
|         # Yield the last partial batch if drop_last is not set. | ||||
|         if len(batch) > 0 and not self.drop_last: | ||||
|             start_idx, end_idx = self.get_start_end_idx() | ||||
|             yield batch[start_idx:end_idx] | ||||
|  | ||||
|  | ||||
| class MegatronPretrainingRandomSampler: | ||||
|  | ||||
|     def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): | ||||
|         # Keep a copy of input params for later use. | ||||
|         self.total_samples = total_samples | ||||
|         self.consumed_samples = consumed_samples | ||||
|         self.micro_batch_size = micro_batch_size | ||||
|         self.data_parallel_rank = data_parallel_rank | ||||
|         self.data_parallel_size = data_parallel_size | ||||
|         self.micro_batch_times_data_parallel_size = \ | ||||
|             self.micro_batch_size * data_parallel_size | ||||
|         self.last_batch_size = \ | ||||
|             self.total_samples % self.micro_batch_times_data_parallel_size | ||||
|  | ||||
|         # Sanity checks. | ||||
|         assert self.total_samples > 0, \ | ||||
|             'no sample to consume: {}'.format(self.total_samples) | ||||
|         assert self.micro_batch_size > 0 | ||||
|         assert data_parallel_size > 0 | ||||
|         assert self.data_parallel_rank < data_parallel_size, \ | ||||
|             'data_parallel_rank should be smaller than data size: {}, ' \ | ||||
|             '{}'.format(self.data_parallel_rank, data_parallel_size) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.total_samples | ||||
|  | ||||
|     def __iter__(self): | ||||
|         active_total_samples = self.total_samples - self.last_batch_size | ||||
|         self.epoch = self.consumed_samples // active_total_samples | ||||
|         current_epoch_samples = self.consumed_samples % active_total_samples | ||||
|         assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 | ||||
|  | ||||
|         # data sharding and random sampling | ||||
|         bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ | ||||
|             * self.micro_batch_size | ||||
|         bucket_offset = current_epoch_samples // self.data_parallel_size | ||||
|         start_idx = self.data_parallel_rank * bucket_size | ||||
|  | ||||
|         g = torch.Generator() | ||||
|         g.manual_seed(self.epoch) | ||||
|         random_idx = torch.randperm(bucket_size, generator=g).tolist() | ||||
|         idx_range = [start_idx + x for x in random_idx[bucket_offset:]] | ||||
|  | ||||
|         batch = [] | ||||
|         # Last batch if not complete will be dropped. | ||||
|         for idx in idx_range: | ||||
|             batch.append(idx) | ||||
|             if len(batch) == self.micro_batch_size: | ||||
|                 self.consumed_samples += self.micro_batch_times_data_parallel_size | ||||
|                 yield batch | ||||
|                 batch = [] | ||||
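To see how `MegatronPretrainingSampler` shards each global batch across data-parallel ranks, the sampler can be instantiated directly, since it only does index arithmetic. The import path is an assumption, and ColossalAI must be installed for the module-level imports to resolve.

```python
from data.datasets.data_samplers import MegatronPretrainingSampler  # assumed path

# 12 samples, micro-batch of 2 per rank, data-parallel size of 2.
for rank in range(2):
    sampler = MegatronPretrainingSampler(total_samples=12,
                                         consumed_samples=0,
                                         micro_batch_size=2,
                                         data_parallel_rank=rank,
                                         data_parallel_size=2)
    print(f"rank {rank}:", list(sampler))
# rank 0: [[0, 1], [4, 5], [8, 9]]
# rank 1: [[2, 3], [6, 7], [10, 11]]
```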
| @@ -0,0 +1,592 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright 2018 The Google AI Language Team Authors, and NVIDIA. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
|  | ||||
| # Most of the code here has been copied from: | ||||
| #   https://github.com/google-research/albert/blob/master/create_pretraining_data.py | ||||
| # with some modifications. | ||||
|  | ||||
| import math | ||||
| import time | ||||
| import collections | ||||
| from colossalai.logging import get_dist_logger | ||||
| import numpy as np | ||||
| from .blendable_dataset import BlendableDataset | ||||
| from .indexed_dataset import make_dataset as make_indexed_dataset | ||||
|  | ||||
| DSET_TYPE_STD = 'standard_bert' | ||||
| DSET_TYPE_ICT = 'ict' | ||||
|  | ||||
| DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] | ||||
|  | ||||
|  | ||||
| def get_datasets_weights_and_num_samples(data_prefix, | ||||
|                                          train_valid_test_num_samples): | ||||
|  | ||||
|     # The data prefix should be in the format of: | ||||
|     #   weight-1, data-prefix-1, weight-2, data-prefix-2, .. | ||||
|     assert len(data_prefix) % 2 == 0 | ||||
|     num_datasets = len(data_prefix) // 2 | ||||
|     weights = [0]*num_datasets | ||||
|     prefixes = [0]*num_datasets | ||||
|     for i in range(num_datasets): | ||||
|         weights[i] = float(data_prefix[2*i]) | ||||
|         prefixes[i] = (data_prefix[2*i+1]).strip() | ||||
|     # Normalize weights | ||||
|     weight_sum = 0.0 | ||||
|     for weight in weights: | ||||
|         weight_sum += weight | ||||
|     assert weight_sum > 0.0 | ||||
|     weights = [weight / weight_sum for weight in weights] | ||||
|  | ||||
| # Add 0.5% (the 1.005 factor) so in case the blending dataset does | ||||
|     # not uniformly distribute the number of samples, we still have | ||||
|     # samples left to feed to the network. | ||||
|     datasets_train_valid_test_num_samples = [] | ||||
|     for weight in weights: | ||||
|         datasets_train_valid_test_num_samples.append( | ||||
|             [int(math.ceil(val * weight * 1.005)) | ||||
|              for val in train_valid_test_num_samples]) | ||||
|  | ||||
|     return prefixes, weights, datasets_train_valid_test_num_samples | ||||
|  | ||||
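A worked example of the interleaved `weight, prefix` format parsed by `get_datasets_weights_and_num_samples`; the paths are placeholders, and the per-dataset sample counts reflect the 1.005 oversampling factor described above.

```python
data_prefix = ['0.7', '/data/corpusA_text_sentence',
               '0.3', '/data/corpusB_text_sentence']
train_valid_test_num_samples = [1000, 100, 10]

prefixes, weights, per_dataset = get_datasets_weights_and_num_samples(
    data_prefix, train_valid_test_num_samples)
print(prefixes)      # ['/data/corpusA_text_sentence', '/data/corpusB_text_sentence']
print(weights)       # [0.7, 0.3]
print(per_dataset)   # [[704, 71, 8], [302, 31, 4]]  = ceil(weight * n * 1.005)
```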
|  | ||||
| def compile_helper(): | ||||
|     """Compile the helper function at runtime. Make sure this | ||||
|     is invoked on a single process.""" | ||||
|     import os | ||||
|     import subprocess | ||||
|     path = os.path.abspath(os.path.dirname(__file__)) | ||||
|     ret = subprocess.run(['make', '-C', path]) | ||||
|     if ret.returncode != 0: | ||||
|         print("Making C++ dataset helpers module failed, exiting.") | ||||
|         import sys | ||||
|         sys.exit(1) | ||||
|  | ||||
|  | ||||
| def get_a_and_b_segments(sample, np_rng): | ||||
|     """Divide sample into a and b segments.""" | ||||
|  | ||||
|     # Number of sentences in the sample. | ||||
|     n_sentences = len(sample) | ||||
|     # Make sure we always have two sentences. | ||||
|     assert n_sentences > 1, 'make sure each sample has at least two sentences.' | ||||
|  | ||||
|     # First part: | ||||
|     # `a_end` is how many sentences go into the `A`. | ||||
|     a_end = 1 | ||||
|     if n_sentences >= 3: | ||||
|         # Note that randint in numpy is exclusive. | ||||
|         a_end = np_rng.randint(1, n_sentences) | ||||
|     tokens_a = [] | ||||
|     for j in range(a_end): | ||||
|         tokens_a.extend(sample[j]) | ||||
|  | ||||
|     # Second part: | ||||
|     tokens_b = [] | ||||
|     for j in range(a_end, n_sentences): | ||||
|         tokens_b.extend(sample[j]) | ||||
|  | ||||
|     # Random next: | ||||
|     is_next_random = False | ||||
|     if np_rng.random() < 0.5: | ||||
|         is_next_random = True | ||||
|         tokens_a, tokens_b = tokens_b, tokens_a | ||||
|  | ||||
|     return tokens_a, tokens_b, is_next_random | ||||
|  | ||||
|  | ||||
| def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): | ||||
|     """Truncates a pair of sequences to a maximum sequence length.""" | ||||
|     #print(len_a, len_b, max_num_tokens) | ||||
|     assert len_a > 0 | ||||
|     if len_a + len_b <= max_num_tokens: | ||||
|         return False | ||||
|     while len_a + len_b > max_num_tokens: | ||||
|         if len_a > len_b: | ||||
|             len_a -= 1 | ||||
|             tokens = tokens_a | ||||
|         else: | ||||
|             len_b -= 1 | ||||
|             tokens = tokens_b | ||||
|         if np_rng.random() < 0.5: | ||||
|             del tokens[0] | ||||
|         else: | ||||
|             tokens.pop() | ||||
|     return True | ||||
|  | ||||
|  | ||||
| def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): | ||||
|     """Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" | ||||
|  | ||||
|     tokens = [] | ||||
|     tokentypes = [] | ||||
|     # [CLS]. | ||||
|     tokens.append(cls_id) | ||||
|     tokentypes.append(0) | ||||
|     # Segment A. | ||||
|     for token in tokens_a: | ||||
|         tokens.append(token) | ||||
|         tokentypes.append(0) | ||||
|     # [SEP]. | ||||
|     tokens.append(sep_id) | ||||
|     tokentypes.append(0) | ||||
|     # Segment B. | ||||
|     for token in tokens_b: | ||||
|         tokens.append(token) | ||||
|         tokentypes.append(1) | ||||
|     if tokens_b: | ||||
|         # [SEP]. | ||||
|         tokens.append(sep_id) | ||||
|         tokentypes.append(1) | ||||
|  | ||||
|     return tokens, tokentypes | ||||
|  | ||||
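A tiny worked example of the `[CLS] A [SEP] B [SEP]` layout produced by `create_tokens_and_tokentypes`; the token ids are made up for illustration.

```python
tokens, tokentypes = create_tokens_and_tokentypes(
    tokens_a=[7, 8, 9], tokens_b=[21, 22], cls_id=101, sep_id=102)
print(tokens)      # [101, 7, 8, 9, 102, 21, 22, 102]
print(tokentypes)  # [0, 0, 0, 0, 0, 1, 1, 1]
```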
|  | ||||
| MaskedLmInstance = collections.namedtuple("MaskedLmInstance", | ||||
|                                           ["index", "label"]) | ||||
|  | ||||
|  | ||||
| def is_start_piece(piece): | ||||
|     """Check if the current word piece is the starting piece (BERT).""" | ||||
|     # When a word has been split into | ||||
|     # WordPieces, the first token does not have any marker and any subsequent | ||||
|     # tokens are prefixed with ##. So whenever we see the ## token, we | ||||
|     # append it to the previous set of word indexes. | ||||
|     return not piece.startswith("##") | ||||
|  | ||||
|  | ||||
| def create_masked_lm_predictions(tokens, | ||||
|                                  vocab_id_list, vocab_id_to_token_dict, | ||||
|                                  masked_lm_prob, | ||||
|                                  cls_id, sep_id, mask_id, | ||||
|                                  max_predictions_per_seq, | ||||
|                                  np_rng, | ||||
|                                  max_ngrams=3, | ||||
|                                  do_whole_word_mask=True, | ||||
|                                  favor_longer_ngram=False, | ||||
|                                  do_permutation=False): | ||||
|     """Creates the predictions for the masked LM objective. | ||||
|     Note: Tokens here are vocab ids and not text tokens.""" | ||||
|  | ||||
|     cand_indexes = [] | ||||
|     # Note(mingdachen): We create a list for recording if the piece is | ||||
|     # the starting piece of current token, where 1 means true, so that | ||||
|     # on-the-fly whole word masking is possible. | ||||
|     token_boundary = [0] * len(tokens) | ||||
|  | ||||
|     for (i, token) in enumerate(tokens): | ||||
|         if token == cls_id or token == sep_id: | ||||
|             token_boundary[i] = 1 | ||||
|             continue | ||||
|         # Whole Word Masking means that we mask all of the wordpieces | ||||
|         # corresponding to an original word. | ||||
|         # | ||||
|         # Note that Whole Word Masking does *not* change the training code | ||||
|         # at all -- we still predict each WordPiece independently, softmaxed | ||||
|         # over the entire vocabulary. | ||||
|         if (do_whole_word_mask and len(cand_indexes) >= 1 and | ||||
|                 not is_start_piece(vocab_id_to_token_dict[token])): | ||||
|             cand_indexes[-1].append(i) | ||||
|         else: | ||||
|             cand_indexes.append([i]) | ||||
|             if is_start_piece(vocab_id_to_token_dict[token]): | ||||
|                 token_boundary[i] = 1 | ||||
|  | ||||
|     output_tokens = list(tokens) | ||||
|  | ||||
|     masked_lm_positions = [] | ||||
|     masked_lm_labels = [] | ||||
|  | ||||
|     if masked_lm_prob == 0: | ||||
|         return (output_tokens, masked_lm_positions, | ||||
|                 masked_lm_labels, token_boundary) | ||||
|  | ||||
|     num_to_predict = min(max_predictions_per_seq, | ||||
|                          max(1, int(round(len(tokens) * masked_lm_prob)))) | ||||
|  | ||||
|     # Note(mingdachen): | ||||
|     # By default, we set the probabilities to favor shorter ngram sequences. | ||||
|     ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) | ||||
|     pvals = 1. / np.arange(1, max_ngrams + 1) | ||||
|     pvals /= pvals.sum(keepdims=True) | ||||
|  | ||||
|     if favor_longer_ngram: | ||||
|         pvals = pvals[::-1] | ||||
|  | ||||
|     ngram_indexes = [] | ||||
|     for idx in range(len(cand_indexes)): | ||||
|         ngram_index = [] | ||||
|         for n in ngrams: | ||||
|             ngram_index.append(cand_indexes[idx:idx + n]) | ||||
|         ngram_indexes.append(ngram_index) | ||||
|  | ||||
|     np_rng.shuffle(ngram_indexes) | ||||
|  | ||||
|     masked_lms = [] | ||||
|     covered_indexes = set() | ||||
|     for cand_index_set in ngram_indexes: | ||||
|         if len(masked_lms) >= num_to_predict: | ||||
|             break | ||||
|         if not cand_index_set: | ||||
|             continue | ||||
|         # Note(mingdachen): | ||||
|         # Skip current piece if they are covered in lm masking or previous ngrams. | ||||
|         for index_set in cand_index_set[0]: | ||||
|             for index in index_set: | ||||
|                 if index in covered_indexes: | ||||
|                     continue | ||||
|  | ||||
|         n = np_rng.choice(ngrams[:len(cand_index_set)], | ||||
|                           p=pvals[:len(cand_index_set)] / | ||||
|                           pvals[:len(cand_index_set)].sum(keepdims=True)) | ||||
|         index_set = sum(cand_index_set[n - 1], []) | ||||
|         n -= 1 | ||||
|         # Note(mingdachen): | ||||
|         # Repeatedly looking for a candidate that does not exceed the | ||||
|         # maximum number of predictions by trying shorter ngrams. | ||||
|         while len(masked_lms) + len(index_set) > num_to_predict: | ||||
|             if n == 0: | ||||
|                 break | ||||
|             index_set = sum(cand_index_set[n - 1], []) | ||||
|             n -= 1 | ||||
|         # If adding a whole-word mask would exceed the maximum number of | ||||
|         # predictions, then just skip this candidate. | ||||
|         if len(masked_lms) + len(index_set) > num_to_predict: | ||||
|             continue | ||||
|         is_any_index_covered = False | ||||
|         for index in index_set: | ||||
|             if index in covered_indexes: | ||||
|                 is_any_index_covered = True | ||||
|                 break | ||||
|         if is_any_index_covered: | ||||
|             continue | ||||
|         for index in index_set: | ||||
|             covered_indexes.add(index) | ||||
|  | ||||
|             masked_token = None | ||||
|             # 80% of the time, replace with [MASK] | ||||
|             if np_rng.random() < 0.8: | ||||
|                 masked_token = mask_id | ||||
|             else: | ||||
|                 # 10% of the time, keep original | ||||
|                 if np_rng.random() < 0.5: | ||||
|                     masked_token = tokens[index] | ||||
|                 # 10% of the time, replace with random word | ||||
|                 else: | ||||
|                     masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))] | ||||
|  | ||||
|             output_tokens[index] = masked_token | ||||
|  | ||||
|             masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) | ||||
|     assert len(masked_lms) <= num_to_predict | ||||
|  | ||||
|     np_rng.shuffle(ngram_indexes) | ||||
|  | ||||
|     select_indexes = set() | ||||
|     if do_permutation: | ||||
|         for cand_index_set in ngram_indexes: | ||||
|             if len(select_indexes) >= num_to_predict: | ||||
|                 break | ||||
|             if not cand_index_set: | ||||
|                 continue | ||||
|             # Note(mingdachen): | ||||
|             # Skip the current piece if it is covered by LM masking or previous ngrams. | ||||
|             for index_set in cand_index_set[0]: | ||||
|                 for index in index_set: | ||||
|                     if index in covered_indexes or index in select_indexes: | ||||
|                         continue | ||||
|  | ||||
|             n = np_rng.choice(ngrams[:len(cand_index_set)], | ||||
|                               p=pvals[:len(cand_index_set)] / | ||||
|                               pvals[:len(cand_index_set)].sum(keepdims=True)) | ||||
|             index_set = sum(cand_index_set[n - 1], []) | ||||
|             n -= 1 | ||||
|  | ||||
|             while len(select_indexes) + len(index_set) > num_to_predict: | ||||
|                 if n == 0: | ||||
|                     break | ||||
|                 index_set = sum(cand_index_set[n - 1], []) | ||||
|                 n -= 1 | ||||
|             # If adding a whole-word mask would exceed the maximum number of | ||||
|             # predictions, then just skip this candidate. | ||||
|             if len(select_indexes) + len(index_set) > num_to_predict: | ||||
|                 continue | ||||
|             is_any_index_covered = False | ||||
|             for index in index_set: | ||||
|                 if index in covered_indexes or index in select_indexes: | ||||
|                     is_any_index_covered = True | ||||
|                     break | ||||
|             if is_any_index_covered: | ||||
|                 continue | ||||
|             for index in index_set: | ||||
|                 select_indexes.add(index) | ||||
|         assert len(select_indexes) <= num_to_predict | ||||
|  | ||||
|         select_indexes = sorted(select_indexes) | ||||
|         permute_indexes = list(select_indexes) | ||||
|         np_rng.shuffle(permute_indexes) | ||||
|         orig_token = list(output_tokens) | ||||
|  | ||||
|         for src_i, tgt_i in zip(select_indexes, permute_indexes): | ||||
|             output_tokens[src_i] = orig_token[tgt_i] | ||||
|             masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i])) | ||||
|  | ||||
|     masked_lms = sorted(masked_lms, key=lambda x: x.index) | ||||
|  | ||||
|     for p in masked_lms: | ||||
|         masked_lm_positions.append(p.index) | ||||
|         masked_lm_labels.append(p.label) | ||||
|  | ||||
|     return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) | ||||
|  | ||||
|  | ||||
| def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, | ||||
|                              masked_labels, pad_id, max_seq_length): | ||||
|     """Pad sequences and convert them to numpy.""" | ||||
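|     # For example, with max_seq_length=5, pad_id=0, tokens=[101, 7592, 102], | ||||
|     # tokentypes=[0, 0, 0], masked_positions=[1] and masked_labels=[7592] | ||||
|     # (illustrative ids), this returns tokens_np=[101, 7592, 102, 0, 0], | ||||
|     # tokentypes_np=[0, 0, 0, 0, 0], labels_np=[-1, 7592, -1, -1, -1], | ||||
|     # padding_mask_np=[1, 1, 1, 0, 0] and loss_mask_np=[0, 1, 0, 0, 0]. | ||||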
|  | ||||
|     # Some checks. | ||||
|     num_tokens = len(tokens) | ||||
|     padding_length = max_seq_length - num_tokens | ||||
|     assert padding_length >= 0 | ||||
|     assert len(tokentypes) == num_tokens | ||||
|     assert len(masked_positions) == len(masked_labels) | ||||
|  | ||||
|     # Tokens and token types. | ||||
|     filler = [pad_id] * padding_length | ||||
|     tokens_np = np.array(tokens + filler, dtype=np.int64) | ||||
|     tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) | ||||
|  | ||||
|     # Padding mask. | ||||
|     padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, | ||||
|                                dtype=np.int64) | ||||
|  | ||||
|     # Labels and loss mask. | ||||
|     labels = [-1] * max_seq_length | ||||
|     loss_mask = [0] * max_seq_length | ||||
|     for i in range(len(masked_positions)): | ||||
|         assert masked_positions[i] < num_tokens | ||||
|         labels[masked_positions[i]] = masked_labels[i] | ||||
|         loss_mask[masked_positions[i]] = 1 | ||||
|     labels_np = np.array(labels, dtype=np.int64) | ||||
|     loss_mask_np = np.array(loss_mask, dtype=np.int64) | ||||
|  | ||||
|     return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np | ||||
|  | ||||
|  | ||||
| def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                                     train_valid_test_num_samples, | ||||
|                                     max_seq_length, masked_lm_prob, | ||||
|                                     short_seq_prob, seed, skip_warmup, | ||||
|                                     binary_head, | ||||
|                                     dataset_type='standard_bert'): | ||||
|  | ||||
|     if len(data_prefix) == 1: | ||||
|         return _build_train_valid_test_datasets(data_prefix[0], | ||||
|                                                 data_impl, splits_string, | ||||
|                                                 train_valid_test_num_samples, | ||||
|                                                 max_seq_length, masked_lm_prob, | ||||
|                                                 short_seq_prob, seed, | ||||
|                                                 skip_warmup, | ||||
|                                                 binary_head, | ||||
|                                                 dataset_type=dataset_type) | ||||
|     # Blending dataset. | ||||
|     # Parse the values. | ||||
|     output = get_datasets_weights_and_num_samples(data_prefix, | ||||
|                                                   train_valid_test_num_samples) | ||||
|     prefixes, weights, datasets_train_valid_test_num_samples = output | ||||
|  | ||||
|     # Build individual datasets. | ||||
|     train_datasets = [] | ||||
|     valid_datasets = [] | ||||
|     test_datasets = [] | ||||
|     for i in range(len(prefixes)): | ||||
|         train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( | ||||
|             prefixes[i], data_impl, splits_string, | ||||
|             datasets_train_valid_test_num_samples[i], | ||||
|             max_seq_length, masked_lm_prob, short_seq_prob, | ||||
|             seed, skip_warmup, binary_head, dataset_type=dataset_type) | ||||
|         if train_ds: | ||||
|             train_datasets.append(train_ds) | ||||
|         if valid_ds: | ||||
|             valid_datasets.append(valid_ds) | ||||
|         if test_ds: | ||||
|             test_datasets.append(test_ds) | ||||
|  | ||||
|     # Blend. | ||||
|     blending_train_dataset = None | ||||
|     if train_datasets: | ||||
|         blending_train_dataset = BlendableDataset(train_datasets, weights) | ||||
|     blending_valid_dataset = None | ||||
|     if valid_datasets: | ||||
|         blending_valid_dataset = BlendableDataset(valid_datasets, weights) | ||||
|     blending_test_dataset = None | ||||
|     if test_datasets: | ||||
|         blending_test_dataset = BlendableDataset(test_datasets, weights) | ||||
|  | ||||
|     return (blending_train_dataset, blending_valid_dataset, | ||||
|             blending_test_dataset) | ||||
|  | ||||
|  | ||||
| def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                                      train_valid_test_num_samples, | ||||
|                                      max_seq_length, masked_lm_prob, | ||||
|                                      short_seq_prob, seed, skip_warmup, | ||||
|                                      binary_head, | ||||
|                                      dataset_type='standard_bert'): | ||||
|     logger = get_dist_logger() | ||||
|  | ||||
|     if dataset_type not in DSET_TYPES: | ||||
|         raise ValueError("Invalid dataset_type: {}".format(dataset_type)) | ||||
|  | ||||
|     # Indexed dataset. | ||||
|     indexed_dataset = get_indexed_dataset_(data_prefix, | ||||
|                                            data_impl, | ||||
|                                            skip_warmup) | ||||
|  | ||||
|     if dataset_type == DSET_TYPE_ICT: | ||||
|         args = get_args() | ||||
|         title_dataset = get_indexed_dataset_(args.titles_data_path, | ||||
|                                              data_impl, | ||||
|                                              skip_warmup) | ||||
|  | ||||
|     # Get start and end indices of train/valid/test into doc-idx | ||||
|     # Note that doc-idx is designed to be num-docs + 1 so we can | ||||
|     # easily iterate over it. | ||||
|     total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1 | ||||
|     splits = get_train_valid_test_split_(splits_string, total_num_of_documents) | ||||
|  | ||||
|     # Print stats about the splits. | ||||
|     logger.info('\n > dataset split:') | ||||
|  | ||||
|     def print_split_stats(name, index): | ||||
|         start_index = indexed_dataset.doc_idx[splits[index]] | ||||
|         end_index = indexed_dataset.doc_idx[splits[index + 1]] | ||||
|         logger.info('\n    {}:'.format(name) + | ||||
|                     '\n     document indices in [{}, {}) total of {} documents'.format( | ||||
|                         splits[index], | ||||
|                         splits[index + 1], | ||||
|                         splits[index + 1] - splits[index]) + | ||||
|                     '\n     sentence indices in [{}, {}) total of {} sentences'.format( | ||||
|                         start_index, | ||||
|                         end_index, | ||||
|                         end_index - start_index), | ||||
|                     ranks=[0]) | ||||
|     print_split_stats('train', 0) | ||||
|     print_split_stats('validation', 1) | ||||
|     print_split_stats('test', 2) | ||||
|  | ||||
|     def build_dataset(index, name): | ||||
|         from .bert_dataset import BertDataset | ||||
|         dataset = None | ||||
|         if splits[index + 1] > splits[index]: | ||||
|             # Get the pointer to the original doc-idx so we can set it later. | ||||
|             doc_idx_ptr = indexed_dataset.get_doc_idx() | ||||
|             # Slice the doc-idx | ||||
|             start_index = splits[index] | ||||
|             # Add +1 so we can index into the dataset to get the upper bound. | ||||
|             end_index = splits[index + 1] + 1 | ||||
|             # New doc_idx view. | ||||
|             indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) | ||||
|             # Build the dataset accordingly. | ||||
|             kwargs = dict( | ||||
|                 name=name, | ||||
|                 data_prefix=data_prefix, | ||||
|                 num_epochs=None, | ||||
|                 max_num_samples=train_valid_test_num_samples[index], | ||||
|                 max_seq_length=max_seq_length, | ||||
|                 seed=seed, | ||||
|                 binary_head=binary_head | ||||
|             ) | ||||
|  | ||||
|             if dataset_type == DSET_TYPE_ICT: | ||||
|                 args = get_args() | ||||
|                 dataset = ICTDataset( | ||||
|                     block_dataset=indexed_dataset, | ||||
|                     title_dataset=title_dataset, | ||||
|                     query_in_block_prob=args.query_in_block_prob, | ||||
|                     use_one_sent_docs=args.use_one_sent_docs, | ||||
|                     **kwargs | ||||
|                 ) | ||||
|             else: | ||||
|                 dataset = BertDataset( | ||||
|                     indexed_dataset=indexed_dataset, | ||||
|                     masked_lm_prob=masked_lm_prob, | ||||
|                     short_seq_prob=short_seq_prob, | ||||
|                     **kwargs | ||||
|                 ) | ||||
|  | ||||
|             # Set the original pointer so dataset remains the main dataset. | ||||
|             indexed_dataset.set_doc_idx(doc_idx_ptr) | ||||
|             # Checks. | ||||
|             assert indexed_dataset.doc_idx[0] == 0 | ||||
|             assert indexed_dataset.doc_idx.shape[0] == \ | ||||
|                 (total_num_of_documents + 1) | ||||
|         return dataset | ||||
|  | ||||
|     train_dataset = build_dataset(0, 'train') | ||||
|     valid_dataset = build_dataset(1, 'valid') | ||||
|     test_dataset = build_dataset(2, 'test') | ||||
|  | ||||
|     return (train_dataset, valid_dataset, test_dataset) | ||||
|  | ||||
|  | ||||
| def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): | ||||
|     logger = get_dist_logger() | ||||
|     logger.info('\n > building dataset index ...', ranks=[0]) | ||||
|     start_time = time.time() | ||||
|     indexed_dataset = make_indexed_dataset(data_prefix, | ||||
|                                            data_impl, | ||||
|                                            skip_warmup) | ||||
|     assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] | ||||
|     logger.info('\n > finished creating indexed dataset in {:.4f} ' | ||||
|                 'seconds'.format(time.time() - start_time), ranks=[0]) | ||||
|     logger.info('\n > indexed dataset stats:' + | ||||
|                 '\n    number of documents: {}'.format( | ||||
|                     indexed_dataset.doc_idx.shape[0] - 1) + | ||||
|                 '\n    number of sentences: {}'.format( | ||||
|                     indexed_dataset.sizes.shape[0]), | ||||
|                 ranks=[0] | ||||
|                 ) | ||||
|  | ||||
|     return indexed_dataset | ||||
|  | ||||
|  | ||||
| def get_train_valid_test_split_(splits_string, size): | ||||
|     """ Get dataset splits from comma or '/' separated string list.""" | ||||
|  | ||||
|     splits = [] | ||||
|     if splits_string.find(',') != -1: | ||||
|         splits = [float(s) for s in splits_string.split(',')] | ||||
|     elif splits_string.find('/') != -1: | ||||
|         splits = [float(s) for s in splits_string.split('/')] | ||||
|     else: | ||||
|         splits = [float(splits_string)] | ||||
|     while len(splits) < 3: | ||||
|         splits.append(0.) | ||||
|     splits = splits[:3] | ||||
|     splits_sum = sum(splits) | ||||
|     assert splits_sum > 0.0 | ||||
|     splits = [split / splits_sum for split in splits] | ||||
|     splits_index = [0] | ||||
|     for index, split in enumerate(splits): | ||||
|         splits_index.append(splits_index[index] + | ||||
|                             int(round(split * float(size)))) | ||||
|     diff = splits_index[-1] - size | ||||
|     for index in range(1, len(splits_index)): | ||||
|         splits_index[index] -= diff | ||||
|     assert len(splits_index) == 4 | ||||
|     assert splits_index[-1] == size | ||||
|     return splits_index | ||||
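get_train_valid_test_split_ turns a weight string such as '949,50,1' and a document count into four cumulative split indices. A minimal sketch of the expected output, assuming the function is imported from this module:

    splits_index = get_train_valid_test_split_('949,50,1', 1000)
    # -> [0, 949, 999, 1000]: train covers documents [0, 949),
    #    validation [949, 999) and test [999, 1000)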
							
								
								
									
717 examples/tutorial/sequence_parallel/data/datasets/helpers.cpp Normal file
							| @@ -0,0 +1,717 @@ | ||||
| /* | ||||
|  coding=utf-8 | ||||
|  Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved. | ||||
|  | ||||
|  Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  you may not use this file except in compliance with the License. | ||||
|  You may obtain a copy of the License at | ||||
|  | ||||
|      http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
|  Unless required by applicable law or agreed to in writing, software | ||||
|  distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  See the License for the specific language governing permissions and | ||||
|  limitations under the License. | ||||
|  */ | ||||
|  | ||||
|  | ||||
| /* Helper methods for fast index mapping builds */ | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <iostream> | ||||
| #include <limits> | ||||
| #include <math.h> | ||||
| #include <stdexcept> | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/numpy.h> | ||||
| #include <random> | ||||
|  | ||||
| namespace py = pybind11; | ||||
| using namespace std; | ||||
|  | ||||
| const int32_t LONG_SENTENCE_LEN = 512; | ||||
|  | ||||
|  | ||||
| void build_blending_indices(py::array_t<uint8_t>& dataset_index, | ||||
| 			    py::array_t<int64_t>& dataset_sample_index, | ||||
| 			    const py::array_t<double>& weights, | ||||
| 			    const int32_t num_datasets, | ||||
| 			    const int64_t size, const bool verbose) { | ||||
|   /* Given multiple datasets and a weighting array, build samples | ||||
|    such that they follow those weights.*/ | ||||
|  | ||||
|   if (verbose) { | ||||
|     std::cout << "> building indices for blendable datasets ..." << std::endl; | ||||
|   } | ||||
|  | ||||
|   // Get the pointer access without the checks. | ||||
|   auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); | ||||
|   auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); | ||||
|   auto weights_ptr = weights.unchecked<1>(); | ||||
|  | ||||
|   // Initialize buffer for number of samples used for each dataset. | ||||
|   int64_t current_samples[num_datasets]; | ||||
|   for(int64_t i = 0; i < num_datasets; ++i) { | ||||
|     current_samples[i] = 0; | ||||
|   } | ||||
|  | ||||
|   // For each sample: | ||||
|   for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { | ||||
|  | ||||
|     // Determine where the max error in sampling is happening. | ||||
|     auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0); | ||||
|     int64_t max_error_index = 0; | ||||
|     double max_error = weights_ptr[0] * sample_idx_double - | ||||
|       static_cast<double>(current_samples[0]); | ||||
|     for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { | ||||
|       double error = weights_ptr[dataset_idx] * sample_idx_double - | ||||
| 	static_cast<double>(current_samples[dataset_idx]); | ||||
|       if (error > max_error) { | ||||
| 	max_error = error; | ||||
| 	max_error_index = dataset_idx; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     // Populate the indices. | ||||
|     dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index); | ||||
|     dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; | ||||
|  | ||||
|     // Update the total samples. | ||||
|     current_samples[max_error_index] += 1; | ||||
|      | ||||
|   } | ||||
|  | ||||
|   // print info | ||||
|   if (verbose) { | ||||
|     std::cout << " > sample ratios:" << std::endl; | ||||
|     for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { | ||||
|       auto ratio = static_cast<double>(current_samples[dataset_idx]) / | ||||
| 	static_cast<double>(size); | ||||
|       std::cout << "   dataset " << dataset_idx << ", input: " << | ||||
| 	weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;  | ||||
|     } | ||||
|   } | ||||
|  | ||||
| } | ||||
|  | ||||
|  | ||||
| py::array build_sample_idx(const py::array_t<int32_t>& sizes_, | ||||
| 			   const py::array_t<int32_t>& doc_idx_, | ||||
| 			   const int32_t seq_length, | ||||
| 			   const int32_t num_epochs, | ||||
| 			   const int64_t tokens_per_epoch) { | ||||
|     /* Sample index (sample_idx) is used for GPT-2-like datasets, for which | ||||
|        the documents are flattened and the samples are built on top of this | ||||
|        1-D flattened array. It is a 2D array with sizes [number-of-samples + 1, 2] | ||||
|        where [..., 0] contains the index into `doc_idx` and [..., 1] is the | ||||
|        starting offset in that document.*/ | ||||
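|     /* For example, with num_epochs=1, tokens_per_epoch=9 and seq_length=4, | ||||
|        num_samples = (1 * 9 - 1) / 4 = 2 and sample_idx has shape [3, 2]. */ | ||||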
|  | ||||
|     // Consistency checks. | ||||
|     assert(seq_length > 1); | ||||
|     assert(num_epochs > 0); | ||||
|     assert(tokens_per_epoch > 1); | ||||
|  | ||||
|     // Remove bound checks. | ||||
|     auto sizes = sizes_.unchecked<1>(); | ||||
|     auto doc_idx = doc_idx_.unchecked<1>(); | ||||
|  | ||||
|     // Mapping and its length (1D). | ||||
|     int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; | ||||
|     int32_t* sample_idx = new int32_t[2*(num_samples+1)]; | ||||
|  | ||||
|     cout << "    using:" << endl << std::flush; | ||||
|     cout << "     number of documents:       " << | ||||
|       doc_idx_.shape(0) / num_epochs << endl << std::flush; | ||||
|     cout << "     number of epochs:          " << num_epochs << | ||||
|       endl << std::flush; | ||||
|     cout << "     sequence length:           " << seq_length << | ||||
|       endl << std::flush; | ||||
|     cout << "     total number of samples:   " << num_samples << | ||||
|       endl << std::flush; | ||||
|  | ||||
|     // Index into sample_idx. | ||||
|     int64_t sample_index = 0; | ||||
|     // Index into doc_idx. | ||||
|     int64_t doc_idx_index = 0; | ||||
|     // Beginning offset for each document. | ||||
|     int32_t doc_offset = 0; | ||||
|     // Start with first document and no offset. | ||||
|     sample_idx[2 * sample_index] = doc_idx_index; | ||||
|     sample_idx[2 * sample_index + 1] = doc_offset; | ||||
|     ++sample_index; | ||||
|  | ||||
|     while (sample_index <= num_samples) { | ||||
|         // Start with a fresh sequence. | ||||
|       int32_t remaining_seq_length = seq_length + 1; | ||||
|       while (remaining_seq_length != 0) { | ||||
|             // Get the document length. | ||||
| 	auto doc_id = doc_idx[doc_idx_index]; | ||||
| 	auto doc_length = sizes[doc_id] - doc_offset; | ||||
| 	// And add it to the current sequence. | ||||
| 	remaining_seq_length -= doc_length; | ||||
| 	// If we have more than a full sequence, adjust offset and set | ||||
| 	// remaining length to zero so we return from the while loop. | ||||
| 	// Note that -1 here is for the same reason we have -1 in | ||||
| 	// `_num_epochs` calculations. | ||||
| 	if (remaining_seq_length <= 0) { | ||||
| 	  doc_offset += (remaining_seq_length + doc_length - 1); | ||||
| 	  remaining_seq_length = 0; | ||||
| 	} else { | ||||
| 	  // Otherwise, start from the beginning of the next document. | ||||
| 	  ++doc_idx_index; | ||||
| 	  doc_offset = 0; | ||||
| 	} | ||||
|       } | ||||
|       // Record the sequence. | ||||
|       sample_idx[2 * sample_index] = doc_idx_index; | ||||
|       sample_idx[2 * sample_index + 1] = doc_offset; | ||||
|       ++sample_index; | ||||
|     } | ||||
|  | ||||
|     // Method to deallocate memory. | ||||
|     py::capsule free_when_done(sample_idx, [](void *mem_) { | ||||
| 	int32_t *mem = reinterpret_cast<int32_t*>(mem_); | ||||
| 	delete[] mem; | ||||
|       }); | ||||
|  | ||||
|     // Return the numpy array. | ||||
|     const auto byte_size = sizeof(int32_t); | ||||
|     return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape | ||||
|                      {2*byte_size, byte_size}, // C-style contiguous strides | ||||
|                      sample_idx, // the data pointer | ||||
|                      free_when_done); // numpy array references | ||||
|      | ||||
| } | ||||
|  | ||||
|  | ||||
| inline int32_t get_target_sample_len(const int32_t short_seq_ratio, | ||||
| 				     const int32_t max_length, | ||||
| 				     std::mt19937& rand32_gen) { | ||||
|     /* Training sample length. */ | ||||
|     if (short_seq_ratio == 0) { | ||||
|       return max_length; | ||||
|     } | ||||
|     const auto random_number = rand32_gen(); | ||||
|     if ((random_number % short_seq_ratio) == 0) { | ||||
|       return 2 + random_number % (max_length - 1); | ||||
|     } | ||||
|     return max_length; | ||||
| } | ||||
|  | ||||
|  | ||||
| template<typename DocIdx> | ||||
| py::array build_mapping_impl(const py::array_t<int64_t>& docs_, | ||||
|                              const py::array_t<int32_t>& sizes_, | ||||
|                              const int32_t num_epochs, | ||||
|                              const uint64_t max_num_samples, | ||||
|                              const int32_t max_seq_length, | ||||
|                              const double short_seq_prob, | ||||
|                              const int32_t seed, | ||||
| 			     const bool verbose, | ||||
| 			     const int32_t min_num_sent) { | ||||
|     /* Build a mapping of (start-index, end-index, sequence-length) where | ||||
|        start and end index are the indices of the sentences in the sample | ||||
|        and sequence-length is the target sequence length. | ||||
|     */ | ||||
|  | ||||
|     // Consistency checks. | ||||
|     assert(num_epochs > 0); | ||||
|     assert(max_seq_length > 1); | ||||
|     assert(short_seq_prob >= 0.0); | ||||
|     assert(short_seq_prob <= 1.0); | ||||
|     assert(seed > 0); | ||||
|  | ||||
|     // Remove bound checks. | ||||
|     auto docs = docs_.unchecked<1>(); | ||||
|     auto sizes = sizes_.unchecked<1>(); | ||||
|  | ||||
|     // For efficiency, convert probability to ratio. Note: rand() generates int. | ||||
|     int32_t short_seq_ratio = 0; | ||||
|     if (short_seq_prob > 0) { | ||||
|       short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob)); | ||||
|     } | ||||
|  | ||||
|     if (verbose) { | ||||
|         const auto sent_start_index = docs[0]; | ||||
| 	const auto sent_end_index = docs[docs_.shape(0) - 1]; | ||||
| 	const auto num_sentences = sent_end_index - sent_start_index; | ||||
| 	cout << "    using:" << endl << std::flush; | ||||
| 	cout << "     number of documents:            " << docs_.shape(0) - 1 << | ||||
| 	  endl << std::flush; | ||||
| 	cout << "     sentences range:                [" << sent_start_index << | ||||
| 	", " << sent_end_index << ")" << endl << std::flush; | ||||
| 	cout << "     total number of sentences:      " << num_sentences << | ||||
| 	  endl << std::flush; | ||||
| 	cout << "     number of epochs:               " << num_epochs << | ||||
| 	  endl << std::flush; | ||||
| 	cout << "     maximum number of samples:      " << max_num_samples << | ||||
| 	  endl << std::flush; | ||||
| 	cout << "     maximum sequence length:        " << max_seq_length << | ||||
| 	  endl << std::flush; | ||||
| 	cout << "     short sequence probability:     " << short_seq_prob << | ||||
| 	endl << std::flush; | ||||
| 	cout << "     short sequence ratio (1/prob): " << short_seq_ratio << | ||||
| 	  endl << std::flush; | ||||
| 	cout << "     seed:                           " << seed << endl << | ||||
| 	  std::flush; | ||||
|     } | ||||
|  | ||||
|     // Mapping and its length (1D). | ||||
|     int64_t num_samples = -1; | ||||
|     DocIdx* maps = NULL; | ||||
|  | ||||
|     // Perform two iterations, in the first iteration get the size | ||||
|     // and allocate memory and in the second iteration populate the map. | ||||
|     bool second = false; | ||||
|     for (int32_t iteration=0; iteration<2; ++iteration) { | ||||
|  | ||||
|         // Set the seed so both iterations produce the same results. | ||||
|         std::mt19937 rand32_gen(seed); | ||||
|  | ||||
|         // Set the flag on second iteration. | ||||
|         second = (iteration == 1); | ||||
|  | ||||
|         // Counters: | ||||
|         uint64_t empty_docs = 0; | ||||
|         uint64_t one_sent_docs = 0; | ||||
| 	uint64_t long_sent_docs = 0; | ||||
|  | ||||
|         // Current map index. | ||||
|         uint64_t map_index = 0; | ||||
|  | ||||
|         // For each epoch: | ||||
|         for (int32_t epoch=0; epoch<num_epochs; ++epoch) { | ||||
|             if (map_index >= max_num_samples) { | ||||
| 	        if (verbose && (!second)) { | ||||
| 		  cout << "    reached " << max_num_samples << " samples after " | ||||
| 		       << epoch << " epochs ..." << endl << std::flush; | ||||
| 		} | ||||
|                 break; | ||||
|             } | ||||
|             // For each document: | ||||
|             for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { | ||||
|  | ||||
|                 // Document sentences are in [sent_index_first, sent_index_last) | ||||
|                 const auto sent_index_first = docs[doc]; | ||||
|                 const auto sent_index_last = docs[doc + 1]; | ||||
|  | ||||
|                 // At the beginning of the document the previous index is the | ||||
| 		// start index. | ||||
|                 auto prev_start_index = sent_index_first; | ||||
|  | ||||
|                 // Remaining sentences in the document. | ||||
|                 auto num_remain_sent = sent_index_last - sent_index_first; | ||||
|  | ||||
|                 // Some bookkeeping | ||||
|                 if ((epoch == 0) && (!second)) { | ||||
|                     if (num_remain_sent == 0) { | ||||
| 		        ++empty_docs; | ||||
|                     } | ||||
|                     if (num_remain_sent == 1) { | ||||
| 		        ++one_sent_docs; | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
| 		// Detect documents with long sentences. | ||||
| 		bool contains_long_sentence = false; | ||||
| 		if (num_remain_sent > 1) { | ||||
| 		    for (auto sent_index=sent_index_first; | ||||
| 			 sent_index < sent_index_last; ++sent_index) { | ||||
| 		        if (sizes[sent_index] > LONG_SENTENCE_LEN){ | ||||
| 			    if ((epoch == 0) && (!second)) { | ||||
| 			        ++long_sent_docs; | ||||
| 			    } | ||||
| 			    contains_long_sentence = true; | ||||
| 			    break; | ||||
| 			} | ||||
| 		    } | ||||
| 		} | ||||
|  | ||||
|                 // If we have enough sentences and no long sentences. | ||||
|                 if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { | ||||
|  | ||||
|                     // Set values. | ||||
|                     auto seq_len = int32_t{0}; | ||||
|                     auto num_sent = int32_t{0}; | ||||
|                     auto target_seq_len = get_target_sample_len(short_seq_ratio, | ||||
| 								max_seq_length, | ||||
| 								rand32_gen); | ||||
|  | ||||
|                     // Loop through sentences. | ||||
|                     for (auto sent_index=sent_index_first; | ||||
|                          sent_index < sent_index_last; ++sent_index) { | ||||
|  | ||||
| 		        // Add the size and number of sentences. | ||||
| 		        seq_len += sizes[sent_index]; | ||||
| 		        ++num_sent; | ||||
| 			--num_remain_sent; | ||||
|  | ||||
| 			// If we have reached the target length. | ||||
| 			// and if not only one sentence is left in the document. | ||||
| 			// and if we have at least two sentences. | ||||
| 			// or if we have reached the end of the document. | ||||
| 			if (((seq_len >= target_seq_len) && | ||||
| 			     (num_remain_sent > 1) && | ||||
| 			     (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { | ||||
|  | ||||
| 			    // Check for overflow. | ||||
| 			    if ((3 * map_index + 2) > | ||||
| 				std::numeric_limits<int64_t>::max()) { | ||||
| 			        cout << "number of samples exceeded maximum " | ||||
| 				     << "allowed by type int64: " | ||||
| 				     << std::numeric_limits<int64_t>::max() | ||||
| 				     << endl; | ||||
| 				throw std::overflow_error("Number of samples"); | ||||
| 			    } | ||||
|  | ||||
| 			    // Populate the map. | ||||
| 			    if (second) { | ||||
| 			        const auto map_index_0 = 3 * map_index; | ||||
| 				maps[map_index_0] = static_cast<DocIdx>(prev_start_index); | ||||
| 				maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1); | ||||
| 				maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len); | ||||
| 			    } | ||||
|  | ||||
| 			    // Update indices / counters. | ||||
| 			    ++map_index; | ||||
| 			    prev_start_index = sent_index + 1; | ||||
| 			    target_seq_len = get_target_sample_len(short_seq_ratio, | ||||
| 								   max_seq_length, | ||||
| 								   rand32_gen); | ||||
| 			    seq_len = 0; | ||||
| 			    num_sent = 0; | ||||
| 			} | ||||
|  | ||||
|                     } // for (auto sent_index=sent_index_first; ... | ||||
|                 } // if (num_remain_sent > 1) { | ||||
|             } // for (int doc=0; doc < num_docs; ++doc) { | ||||
|         } // for (int epoch=0; epoch < num_epochs; ++epoch) { | ||||
|  | ||||
|         if (!second) { | ||||
| 	    if (verbose) { | ||||
| 	        cout << "   number of empty documents: " << empty_docs << | ||||
| 		  endl << std::flush; | ||||
| 		cout << "   number of documents with one sentence: " << | ||||
| 		  one_sent_docs << endl << std::flush; | ||||
| 		cout << "   number of documents with long sentences: " << | ||||
| 		  long_sent_docs << endl << std::flush; | ||||
| 		cout << "   will create mapping for " << map_index << | ||||
| 		  " samples" << endl << std::flush; | ||||
| 	    } | ||||
| 	    assert(maps == NULL); | ||||
| 	    assert(num_samples < 0); | ||||
|             maps = new DocIdx[3*map_index]; | ||||
|             num_samples = static_cast<int64_t>(map_index); | ||||
|         } | ||||
|  | ||||
|     } // for (int iteration=0; iteration < 2; ++iteration) { | ||||
|  | ||||
|     // Shuffle. | ||||
|     // We need a 64 bit random number generator as we might have more | ||||
|     // than 2 billion samples. | ||||
|     std::mt19937_64 rand64_gen(seed + 1); | ||||
|     for (auto i=(num_samples - 1); i > 0; --i) { | ||||
|       const auto j = static_cast<int64_t>(rand64_gen() % (i + 1)); | ||||
|       const auto i0 = 3 * i; | ||||
|       const auto j0 = 3 * j; | ||||
|       // Swap values. | ||||
|       swap(maps[i0], maps[j0]); | ||||
|       swap(maps[i0 + 1], maps[j0 + 1]); | ||||
|       swap(maps[i0 + 2], maps[j0 + 2]); | ||||
|     } | ||||
|  | ||||
|     // Method to deallocate memory. | ||||
|     py::capsule free_when_done(maps, [](void *mem_) { | ||||
|             DocIdx *mem = reinterpret_cast<DocIdx*>(mem_); | ||||
| 	    delete[] mem; | ||||
|         }); | ||||
|  | ||||
|     // Return the numpy array. | ||||
|     const auto byte_size = sizeof(DocIdx); | ||||
|     return py::array(std::vector<int64_t>{num_samples, 3}, // shape | ||||
|                      {3*byte_size, byte_size}, // C-style contiguous strides | ||||
|                      maps, // the data pointer | ||||
|                      free_when_done); // numpy array references | ||||
|  | ||||
| } | ||||
|  | ||||
|  | ||||
| py::array build_mapping(const py::array_t<int64_t>& docs_, | ||||
|                         const py::array_t<int>& sizes_, | ||||
|                         const int num_epochs, | ||||
|                         const uint64_t max_num_samples, | ||||
|                         const int max_seq_length, | ||||
|                         const double short_seq_prob, | ||||
|                         const int seed, | ||||
| 			const bool verbose, | ||||
| 			const int32_t min_num_sent) { | ||||
|  | ||||
|     if (sizes_.size() > std::numeric_limits<uint32_t>::max()) { | ||||
|         if (verbose) { | ||||
| 	   cout << "    using uint64 for data mapping..." << endl << std::flush; | ||||
| 	} | ||||
| 	return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs, | ||||
| 					    max_num_samples, max_seq_length, | ||||
| 					    short_seq_prob, seed, verbose, | ||||
| 					    min_num_sent); | ||||
|     } else { | ||||
|        if (verbose) { | ||||
| 	   cout << "    using uint32 for data mapping..." << endl << std::flush; | ||||
|        } | ||||
|        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs, | ||||
| 					   max_num_samples, max_seq_length, | ||||
| 					   short_seq_prob, seed, verbose, | ||||
| 					   min_num_sent); | ||||
|     } | ||||
| } | ||||
|  | ||||
| template<typename DocIdx> | ||||
| py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_, | ||||
|                                     const py::array_t<int32_t>& sizes_, | ||||
|                                     const py::array_t<int32_t>& titles_sizes_, | ||||
|                                     const int32_t num_epochs, | ||||
|                                     const uint64_t max_num_samples, | ||||
|                                     const int32_t max_seq_length, | ||||
|                                     const int32_t seed, | ||||
|                                     const bool verbose, | ||||
|                                     const bool use_one_sent_blocks) { | ||||
|     /* Build a mapping of (start-index, end-index, doc-index, block-id) where | ||||
|        start and end index are the indices of the sentences in the sample, | ||||
|        doc-index is the document the block comes from (used for fetching titles) | ||||
|        and block-id is a unique id for the block. | ||||
|     */ | ||||
|  | ||||
|     // Consistency checks. | ||||
|     assert(num_epochs > 0); | ||||
|     assert(max_seq_length > 1); | ||||
|     assert(seed > 0); | ||||
|  | ||||
|     // Remove bound checks. | ||||
|     auto docs = docs_.unchecked<1>(); | ||||
|     auto sizes = sizes_.unchecked<1>(); | ||||
|     auto titles_sizes = titles_sizes_.unchecked<1>(); | ||||
|  | ||||
|     if (verbose) { | ||||
|         const auto sent_start_index = docs[0]; | ||||
|         const auto sent_end_index = docs[docs_.shape(0) - 1]; | ||||
|         const auto num_sentences = sent_end_index - sent_start_index; | ||||
|         cout << "    using:" << endl << std::flush; | ||||
|         cout << "     number of documents:            " << docs_.shape(0) - 1 << | ||||
|           endl << std::flush; | ||||
|         cout << "     sentences range:                [" << sent_start_index << | ||||
|         ", " << sent_end_index << ")" << endl << std::flush; | ||||
|         cout << "     total number of sentences:      " << num_sentences << | ||||
|           endl << std::flush; | ||||
|         cout << "     number of epochs:               " << num_epochs << | ||||
|           endl << std::flush; | ||||
|         cout << "     maximum number of samples:      " << max_num_samples << | ||||
|           endl << std::flush; | ||||
|         cout << "     maximum sequence length:        " << max_seq_length << | ||||
|           endl << std::flush; | ||||
|         cout << "     seed:                           " << seed << endl << | ||||
|           std::flush; | ||||
|     } | ||||
|  | ||||
|     // Mapping and its length (1D). | ||||
|     int64_t num_samples = -1; | ||||
|     DocIdx* maps = NULL; | ||||
|  | ||||
|     // Acceptable number of sentences per block. | ||||
|     int min_num_sent = 2; | ||||
|     if (use_one_sent_blocks) { | ||||
|         min_num_sent = 1; | ||||
|     } | ||||
|  | ||||
|     // Perform two iterations, in the first iteration get the size | ||||
|     // and allocate memory and in the second iteration populate the map. | ||||
|     bool second = false; | ||||
|     for (int32_t iteration=0; iteration<2; ++iteration) { | ||||
|  | ||||
|         // Set the flag on second iteration. | ||||
|         second = (iteration == 1); | ||||
|  | ||||
|         // Current map index. | ||||
|         uint64_t map_index = 0; | ||||
|  | ||||
|         uint64_t empty_docs = 0; | ||||
|         uint64_t one_sent_docs = 0; | ||||
|         uint64_t long_sent_docs = 0; | ||||
|         // For each epoch: | ||||
|         for (int32_t epoch=0; epoch<num_epochs; ++epoch) { | ||||
|             // assign every block a unique id | ||||
|             int32_t block_id = 0; | ||||
|  | ||||
|             if (map_index >= max_num_samples) { | ||||
|                 if (verbose && (!second)) { | ||||
|                 cout << "    reached " << max_num_samples << " samples after " | ||||
|                      << epoch << " epochs ..." << endl << std::flush; | ||||
|                 } | ||||
|                 break; | ||||
|             } | ||||
|             // For each document: | ||||
|             for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { | ||||
|  | ||||
|                 // Document sentences are in [sent_index_first, sent_index_last) | ||||
|                 const auto sent_index_first = docs[doc]; | ||||
|                 const auto sent_index_last = docs[doc + 1]; | ||||
|                 const auto target_seq_len = max_seq_length - titles_sizes[doc]; | ||||
|  | ||||
|                 // At the beginning of the document the previous index is the | ||||
|                 // start index. | ||||
|                 auto prev_start_index = sent_index_first; | ||||
|  | ||||
|                 // Remaining sentences in the document. | ||||
|                 auto num_remain_sent = sent_index_last - sent_index_first; | ||||
|  | ||||
|                 // Some bookkeeping | ||||
|                 if ((epoch == 0) && (!second)) { | ||||
|                     if (num_remain_sent == 0) { | ||||
| 		                ++empty_docs; | ||||
|                     } | ||||
|                     if (num_remain_sent == 1) { | ||||
| 		                ++one_sent_docs; | ||||
|                     } | ||||
|                 } | ||||
|                 // Detect documents with long sentences. | ||||
|                 bool contains_long_sentence = false; | ||||
|                 if (num_remain_sent >= min_num_sent) { | ||||
|                     for (auto sent_index=sent_index_first; | ||||
|                     sent_index < sent_index_last; ++sent_index) { | ||||
|                         if (sizes[sent_index] > LONG_SENTENCE_LEN){ | ||||
|                             if ((epoch == 0) && (!second)) { | ||||
|                                 ++long_sent_docs; | ||||
|                             } | ||||
|                             contains_long_sentence = true; | ||||
|                             break; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 // If we have enough sentences and no long sentences. | ||||
|                 if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { | ||||
|  | ||||
|                     // Set values. | ||||
|                     auto seq_len = int32_t{0}; | ||||
|                     auto num_sent = int32_t{0}; | ||||
|  | ||||
|                     // Loop through sentences. | ||||
|                     for (auto sent_index=sent_index_first; | ||||
|                          sent_index < sent_index_last; ++sent_index) { | ||||
|  | ||||
|                             // Add the size and number of sentences. | ||||
|                             seq_len += sizes[sent_index]; | ||||
|                             ++num_sent; | ||||
|                             --num_remain_sent; | ||||
|  | ||||
|                         // If we have reached the target length. | ||||
|                         // and there are an acceptable number of sentences left | ||||
|                         // and if we have at least the minimum number of sentences. | ||||
|                         // or if we have reached end of the document. | ||||
|                         if (((seq_len >= target_seq_len) && | ||||
|                              (num_remain_sent >= min_num_sent) && | ||||
|                              (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { | ||||
|  | ||||
|                             // Populate the map. | ||||
|                             if (second) { | ||||
|                                 const auto map_index_0 = 4 * map_index; | ||||
|                                 // Each sample has 4 items: the starting sentence index, ending sentence index, | ||||
|                                 // the index of the document from which the block comes (used for fetching titles) | ||||
|                                 // and the unique id of the block (used for creating block indexes) | ||||
|  | ||||
|                                 maps[map_index_0] = static_cast<DocIdx>(prev_start_index); | ||||
|                                 maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1); | ||||
|                                 maps[map_index_0 + 2] = static_cast<DocIdx>(doc); | ||||
|                                 maps[map_index_0 + 3] = static_cast<DocIdx>(block_id); | ||||
|                             } | ||||
|  | ||||
|                             // Update indices / counters. | ||||
|                             ++map_index; | ||||
|                             ++block_id; | ||||
|                             prev_start_index = sent_index + 1; | ||||
|                             seq_len = 0; | ||||
|                             num_sent = 0; | ||||
|                         } | ||||
|                     } // for (auto sent_index=sent_index_first; ... | ||||
|                 } // if (num_remain_sent > 1) { | ||||
|             } // for (int doc=0; doc < num_docs; ++doc) { | ||||
|         } // for (int epoch=0; epoch < num_epochs; ++epoch) { | ||||
|  | ||||
|         if (!second) { | ||||
|             if (verbose) { | ||||
| 	        cout << "   number of empty documents: " << empty_docs << | ||||
|               endl << std::flush; | ||||
|             cout << "   number of documents with one sentence: " << | ||||
|               one_sent_docs << endl << std::flush; | ||||
|             cout << "   number of documents with long sentences: " << | ||||
|               long_sent_docs << endl << std::flush; | ||||
|             cout << "   will create mapping for " << map_index << | ||||
|               " samples" << endl << std::flush; | ||||
|             } | ||||
|             assert(maps == NULL); | ||||
|             assert(num_samples < 0); | ||||
|             maps = new DocIdx[4*map_index]; | ||||
|             num_samples = static_cast<int64_t>(map_index); | ||||
|         } | ||||
|  | ||||
|     } // for (int iteration=0; iteration < 2; ++iteration) { | ||||
|  | ||||
|     // Shuffle. | ||||
|     // We need a 64 bit random number generator as we might have more | ||||
|     // than 2 billion samples. | ||||
|     std::mt19937_64 rand64_gen(seed + 1); | ||||
|     for (auto i=(num_samples - 1); i > 0; --i) { | ||||
|         const auto j = static_cast<int64_t>(rand64_gen() % (i + 1)); | ||||
|         const auto i0 = 4 * i; | ||||
|         const auto j0 = 4 * j; | ||||
|         // Swap values. | ||||
|         swap(maps[i0], maps[j0]); | ||||
|         swap(maps[i0 + 1], maps[j0 + 1]); | ||||
|         swap(maps[i0 + 2], maps[j0 + 2]); | ||||
|         swap(maps[i0 + 3], maps[j0 + 3]); | ||||
|     } | ||||
|  | ||||
|     // Method to deallocate memory. | ||||
|     py::capsule free_when_done(maps, [](void *mem_) { | ||||
|             DocIdx *mem = reinterpret_cast<DocIdx*>(mem_); | ||||
| 	    delete[] mem; | ||||
|         }); | ||||
|  | ||||
|     // Return the numpy array. | ||||
|     const auto byte_size = sizeof(DocIdx); | ||||
|     return py::array(std::vector<int64_t>{num_samples, 4}, // shape | ||||
|                      {4*byte_size, byte_size}, // C-style contiguous strides | ||||
|                      maps, // the data pointer | ||||
|                      free_when_done); // numpy array references | ||||
|  | ||||
| } | ||||
|  | ||||
| py::array build_blocks_mapping(const py::array_t<int64_t>& docs_, | ||||
|                                const py::array_t<int>& sizes_, | ||||
|                                const py::array_t<int>& titles_sizes_, | ||||
|                                const int num_epochs, | ||||
|                                const uint64_t max_num_samples, | ||||
|                                const int max_seq_length, | ||||
|                                const int seed, | ||||
|                     const bool verbose, | ||||
|                     const bool use_one_sent_blocks) { | ||||
|  | ||||
|     if (sizes_.size() > std::numeric_limits<uint32_t>::max()) { | ||||
|         if (verbose) { | ||||
| 	   cout << "    using uint64 for data mapping..." << endl << std::flush; | ||||
| 	} | ||||
| 	return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_, | ||||
| 	                    num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); | ||||
|     } else { | ||||
|        if (verbose) { | ||||
| 	   cout << "    using uint32 for data mapping..." << endl << std::flush; | ||||
|        } | ||||
|        return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_, | ||||
|                         num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); | ||||
|     } | ||||
| } | ||||
|  | ||||
| PYBIND11_MODULE(helpers, m) { | ||||
|     m.def("build_mapping", &build_mapping); | ||||
|     m.def("build_blocks_mapping", &build_blocks_mapping); | ||||
|     m.def("build_sample_idx", &build_sample_idx); | ||||
|     m.def("build_blending_indices", &build_blending_indices); | ||||
| } | ||||
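These bindings are called from Python with positional arguments once the extension is compiled. A minimal sketch for build_sample_idx, assuming the compiled module is importable as helpers (the actual import path depends on how the example builds the extension):

    import numpy as np
    import helpers  # assumed name of the compiled pybind11 extension

    sizes = np.array([5, 3, 7], dtype=np.int32)    # tokens per document
    doc_idx = np.array([0, 1, 2], dtype=np.int32)  # (shuffled) document order
    # arguments: sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
    sample_idx = helpers.build_sample_idx(sizes, doc_idx, 4, 1, int(sizes.sum()))
    print(sample_idx.shape)  # (num_samples + 1, 2) -> (4, 2) for this toy input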
							
								
								
									
156 examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py Normal file
							| @@ -0,0 +1,156 @@ | ||||
| import itertools | ||||
| import random | ||||
|  | ||||
| import numpy as np | ||||
| from torch.utils.data import Dataset | ||||
|  | ||||
| from megatron import get_tokenizer | ||||
| from megatron import get_args | ||||
| from megatron.data.dataset_utils import get_indexed_dataset_ | ||||
| from megatron.data.realm_dataset_utils import get_block_samples_mapping | ||||
|  | ||||
| def make_attention_mask(source_block, target_block): | ||||
|     """ | ||||
|     Returns a 2-dimensional (2-D) attention mask | ||||
|     :param source_block: 1-D array | ||||
|     :param target_block: 1-D array | ||||
|     """ | ||||
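|     # For example, source_block=[101, 7592, 0] and target_block=[101, 2088, 102, 0] | ||||
|     # (0 being padding) give a (3, 4) mask whose last row and last column are zero. | ||||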
|     mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1) | ||||
|     mask = mask.astype(np.int64) | ||||
|     # (source_length, target_length) | ||||
|     return mask | ||||
|  | ||||
| def get_ict_dataset(use_titles=True, query_in_block_prob=1): | ||||
|     """Get a dataset which uses block sample mappings to get ICT/block indexing data (via get_block()), | ||||
|     rather than for training, since it is only built with a single-epoch sample mapping. | ||||
|     """ | ||||
|     args = get_args() | ||||
|     block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) | ||||
|     titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) | ||||
|  | ||||
|     kwargs = dict( | ||||
|         name='full', | ||||
|         block_dataset=block_dataset, | ||||
|         title_dataset=titles_dataset, | ||||
|         data_prefix=args.data_path, | ||||
|         num_epochs=1, | ||||
|         max_num_samples=None, | ||||
|         max_seq_length=args.seq_length, | ||||
|         seed=1, | ||||
|         query_in_block_prob=query_in_block_prob, | ||||
|         use_titles=use_titles, | ||||
|         use_one_sent_docs=args.use_one_sent_docs | ||||
|     ) | ||||
|     dataset = ICTDataset(**kwargs) | ||||
|     return dataset | ||||
|  | ||||
|  | ||||
| class ICTDataset(Dataset): | ||||
|     """Dataset containing sentences and their blocks for an inverse cloze task.""" | ||||
|     def __init__(self, name, block_dataset, title_dataset, data_prefix, | ||||
|                  num_epochs, max_num_samples, max_seq_length, query_in_block_prob, | ||||
|                  seed, use_titles=True, use_one_sent_docs=False, binary_head=False): | ||||
|         self.name = name | ||||
|         self.seed = seed | ||||
|         self.max_seq_length = max_seq_length | ||||
|         self.query_in_block_prob = query_in_block_prob | ||||
|         self.block_dataset = block_dataset | ||||
|         self.title_dataset = title_dataset | ||||
|         self.rng = random.Random(self.seed) | ||||
|         self.use_titles = use_titles | ||||
|         self.use_one_sent_docs = use_one_sent_docs | ||||
|  | ||||
|         self.samples_mapping = get_block_samples_mapping( | ||||
|             block_dataset, title_dataset, data_prefix, num_epochs, | ||||
|             max_num_samples, max_seq_length, seed, name, use_one_sent_docs) | ||||
|         self.tokenizer = get_tokenizer() | ||||
|         self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) | ||||
|         self.vocab_id_to_token_list = self.tokenizer.inv_vocab | ||||
|         self.cls_id = self.tokenizer.cls | ||||
|         self.sep_id = self.tokenizer.sep | ||||
|         self.mask_id = self.tokenizer.mask | ||||
|         self.pad_id = self.tokenizer.pad | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.samples_mapping) | ||||
|  | ||||
|     def __getitem__(self, idx): | ||||
|         """Get an ICT example of a pseudo-query and the block of text from which it was extracted""" | ||||
|         sample_data = self.samples_mapping[idx] | ||||
|         start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple() | ||||
|  | ||||
|         if self.use_titles: | ||||
|             title = self.title_dataset[int(doc_idx)] | ||||
|             title_pad_offset = 3 + len(title) | ||||
|         else: | ||||
|             title = None | ||||
|             title_pad_offset = 2 | ||||
|         block = [self.block_dataset[i] for i in range(start_idx, end_idx)] | ||||
|         assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1 | ||||
|  | ||||
|         # randint() is inclusive for Python rng | ||||
|         rand_sent_idx = self.rng.randint(0, len(block) - 1) | ||||
|  | ||||
|         # keep the query in the context query_in_block_prob fraction of the time. | ||||
|         if self.rng.random() < self.query_in_block_prob: | ||||
|             query = block[rand_sent_idx].copy() | ||||
|         else: | ||||
|             query = block.pop(rand_sent_idx) | ||||
|  | ||||
|         # still need to truncate because blocks are concluded when | ||||
|         # the sentence lengths have exceeded max_seq_length. | ||||
|         query = query[:self.max_seq_length - 2] | ||||
|         block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] | ||||
|  | ||||
|         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) | ||||
|         context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) | ||||
|  | ||||
|         query_mask = make_attention_mask(query_tokens, query_tokens) | ||||
|         context_mask = make_attention_mask(context_tokens, context_tokens) | ||||
|  | ||||
|         block_data = sample_data.as_array() | ||||
|  | ||||
|         sample = { | ||||
|             'query_tokens': query_tokens, | ||||
|             'query_mask': query_mask, | ||||
|             'query_pad_mask': query_pad_mask, | ||||
|             'context_tokens': context_tokens, | ||||
|             'context_mask': context_mask, | ||||
|             'context_pad_mask': context_pad_mask, | ||||
|             'block_data': block_data, | ||||
|         } | ||||
|  | ||||
|         return sample | ||||
|  | ||||
|     def get_block(self, start_idx, end_idx, doc_idx): | ||||
|         """Get the IDs for an evidence block plus the title of the corresponding document""" | ||||
|         block = [self.block_dataset[i] for i in range(start_idx, end_idx)] | ||||
|         title = self.title_dataset[int(doc_idx)] | ||||
|  | ||||
|         block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] | ||||
|         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) | ||||
|  | ||||
|         return block_tokens, block_pad_mask | ||||
|  | ||||
|     def get_null_block(self): | ||||
|         """Get empty block and title - used in REALM pretraining""" | ||||
|         block, title = [], [] | ||||
|         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) | ||||
|  | ||||
|         return block_tokens, block_pad_mask | ||||
|  | ||||
|     def concat_and_pad_tokens(self, tokens, title=None): | ||||
|         """Concat with special tokens and pad sequence to self.max_seq_length""" | ||||
|         tokens = list(tokens) | ||||
|         if title is None: | ||||
|             tokens = [self.cls_id] + tokens + [self.sep_id] | ||||
|         else: | ||||
|             title = list(title) | ||||
|             tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id] | ||||
|         assert len(tokens) <= self.max_seq_length | ||||
|  | ||||
|         num_pad = self.max_seq_length - len(tokens) | ||||
|         pad_mask = [1] * len(tokens) + [0] * num_pad | ||||
|         tokens += [self.pad_id] * num_pad | ||||
|  | ||||
|         return np.array(tokens), np.array(pad_mask) | ||||
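|  | ||||
|     # Worked example (illustrative only, the ids are made up): with max_seq_length = 8, | ||||
|     # cls_id = 101, sep_id = 102 and pad_id = 0, | ||||
|     #   concat_and_pad_tokens([7, 8, 9])         -> ([101, 7, 8, 9, 102, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]) | ||||
|     #   concat_and_pad_tokens([7, 8], title=[5]) -> ([101, 5, 102, 7, 8, 102, 0, 0], [1, 1, 1, 1, 1, 1, 0, 0]) | ||||
|     # which is why __getitem__ reserves 2 positions without a title and 3 + len(title) with one. | ||||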
| @@ -0,0 +1,569 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # | ||||
| # This source code is licensed under the MIT license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|  | ||||
|  | ||||
| # copied from fairseq/fairseq/data/indexed_dataset.py | ||||
| # Removed IndexedRawTextDataset since it relied on Fairseq dictionary | ||||
| # other slight modifications to remove fairseq dependencies | ||||
| # Added document index to index file and made it accessible. | ||||
| #    An empty sentence no longer separates documents. | ||||
|  | ||||
| from functools import lru_cache | ||||
| import os | ||||
| import shutil | ||||
| import struct | ||||
| from itertools import accumulate | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
|  | ||||
| def __best_fitting_dtype(vocab_size=None): | ||||
|     if vocab_size is not None and vocab_size < 65500: | ||||
|         return np.uint16 | ||||
|     else: | ||||
|         return np.int32 | ||||
|  | ||||
|  | ||||
| def get_available_dataset_impl(): | ||||
|     return ['lazy', 'cached', 'mmap'] | ||||
|  | ||||
|  | ||||
| def infer_dataset_impl(path): | ||||
|     if IndexedDataset.exists(path): | ||||
|         with open(index_file_path(path), 'rb') as f: | ||||
|             magic = f.read(8) | ||||
|             if magic == IndexedDataset._HDR_MAGIC: | ||||
|                 return 'cached' | ||||
|             elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: | ||||
|                 return 'mmap' | ||||
|             else: | ||||
|                 return None | ||||
|     else: | ||||
|         print(f"Dataset does not exist: {path}") | ||||
|         print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") | ||||
|         return None | ||||
|  | ||||
|  | ||||
| def make_builder(out_file, impl, vocab_size=None): | ||||
|     if impl == 'mmap': | ||||
|         return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) | ||||
|     else: | ||||
|         return IndexedDatasetBuilder(out_file) | ||||
|  | ||||
|  | ||||
| def make_dataset(path, impl, skip_warmup=False): | ||||
|     if not IndexedDataset.exists(path): | ||||
|         print(f"Dataset does not exist: {path}") | ||||
|         print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") | ||||
|         return None | ||||
|     if impl == 'infer': | ||||
|         impl = infer_dataset_impl(path) | ||||
|     if impl == 'lazy' and IndexedDataset.exists(path): | ||||
|         return IndexedDataset(path) | ||||
|     elif impl == 'cached' and IndexedDataset.exists(path): | ||||
|         return IndexedCachedDataset(path) | ||||
|     elif impl == 'mmap' and MMapIndexedDataset.exists(path): | ||||
|         return MMapIndexedDataset(path, skip_warmup) | ||||
|     print(f"Unknown dataset implementation: {impl}") | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def dataset_exists(path, impl): | ||||
|     if impl == 'mmap': | ||||
|         return MMapIndexedDataset.exists(path) | ||||
|     else: | ||||
|         return IndexedDataset.exists(path) | ||||
|  | ||||
|  | ||||
| def read_longs(f, n): | ||||
|     a = np.empty(n, dtype=np.int64) | ||||
|     f.readinto(a) | ||||
|     return a | ||||
|  | ||||
|  | ||||
| def write_longs(f, a): | ||||
|     f.write(np.array(a, dtype=np.int64)) | ||||
|  | ||||
|  | ||||
| dtypes = { | ||||
|     1: np.uint8, | ||||
|     2: np.int8, | ||||
|     3: np.int16, | ||||
|     4: np.int32, | ||||
|     5: np.int64, | ||||
|     6: np.float32,    # was np.float, an alias removed in NumPy 1.24; float32 matches the 4-byte element size used by the builder | ||||
|     7: np.double, | ||||
|     8: np.uint16 | ||||
| } | ||||
|  | ||||
|  | ||||
| def code(dtype): | ||||
|     for k in dtypes.keys(): | ||||
|         if dtypes[k] == dtype: | ||||
|             return k | ||||
|     raise ValueError(dtype) | ||||
|  | ||||
|  | ||||
| def index_file_path(prefix_path): | ||||
|     return prefix_path + '.idx' | ||||
|  | ||||
|  | ||||
| def data_file_path(prefix_path): | ||||
|     return prefix_path + '.bin' | ||||
|  | ||||
|  | ||||
| def create_doc_idx(sizes): | ||||
|     doc_idx = [0] | ||||
|     for i, s in enumerate(sizes): | ||||
|         if s == 0: | ||||
|             doc_idx.append(i + 1) | ||||
|     return doc_idx | ||||
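|  | ||||
| # Worked example (illustrative): create_doc_idx([3, 4, 0, 2, 0]) == [0, 3, 5]. | ||||
| # A zero-length "sentence" acts as a document terminator, so doc_idx records the | ||||
| # sentence index at which each new document starts. | ||||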
|  | ||||
|  | ||||
| class IndexedDataset(torch.utils.data.Dataset): | ||||
|     """Loader for IndexedDataset""" | ||||
|     _HDR_MAGIC = b'TNTIDX\x00\x00' | ||||
|  | ||||
|     def __init__(self, path): | ||||
|         super().__init__() | ||||
|         self.path = path | ||||
|         self.data_file = None | ||||
|         self.read_index(path) | ||||
|  | ||||
|     def read_index(self, path): | ||||
|         with open(index_file_path(path), 'rb') as f: | ||||
|             magic = f.read(8) | ||||
|             assert magic == self._HDR_MAGIC, ( | ||||
|                 'Index file doesn\'t match expected format. ' | ||||
|                 'Make sure that --dataset-impl is configured properly.' | ||||
|             ) | ||||
|             version = f.read(8) | ||||
|             assert struct.unpack('<Q', version) == (1,) | ||||
|             code, self.element_size = struct.unpack('<QQ', f.read(16)) | ||||
|             self.dtype = dtypes[code] | ||||
|             self._len, self.s = struct.unpack('<QQ', f.read(16)) | ||||
|             self.doc_count = struct.unpack('<Q', f.read(8)) | ||||
|             self.dim_offsets = read_longs(f, self._len + 1) | ||||
|             self.data_offsets = read_longs(f, self._len + 1) | ||||
|             self.sizes = read_longs(f, self.s) | ||||
|             self.doc_idx = read_longs(f, self.doc_count) | ||||
|  | ||||
|     def read_data(self, path): | ||||
|         self.data_file = open(data_file_path(path), 'rb', buffering=0) | ||||
|  | ||||
|     def check_index(self, i): | ||||
|         if i < 0 or i >= self._len: | ||||
|             raise IndexError('index out of range') | ||||
|  | ||||
|     def __del__(self): | ||||
|         if self.data_file: | ||||
|             self.data_file.close() | ||||
|  | ||||
|     # @lru_cache(maxsize=8) | ||||
|     def __getitem__(self, idx): | ||||
|         if not self.data_file: | ||||
|             self.read_data(self.path) | ||||
|         if isinstance(idx, int): | ||||
|             i = idx | ||||
|             self.check_index(i) | ||||
|             tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] | ||||
|             a = np.empty(tensor_size, dtype=self.dtype) | ||||
|             self.data_file.seek(self.data_offsets[i] * self.element_size) | ||||
|             self.data_file.readinto(a) | ||||
|             return a | ||||
|         elif isinstance(idx, slice): | ||||
|             start, stop, step = idx.indices(len(self)) | ||||
|             if step != 1: | ||||
|                 raise ValueError("Slices into indexed_dataset must be contiguous") | ||||
|             sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] | ||||
|             size = sum(sizes) | ||||
|             a = np.empty(size, dtype=self.dtype) | ||||
|             self.data_file.seek(self.data_offsets[start] * self.element_size) | ||||
|             self.data_file.readinto(a) | ||||
|             offsets = list(accumulate(sizes)) | ||||
|             sents = np.split(a, offsets[:-1]) | ||||
|             return sents | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self._len | ||||
|  | ||||
|     def num_tokens(self, index): | ||||
|         return self.sizes[index] | ||||
|  | ||||
|     def size(self, index): | ||||
|         return self.sizes[index] | ||||
|  | ||||
|     @staticmethod | ||||
|     def exists(path): | ||||
|         return ( | ||||
|             os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def supports_prefetch(self): | ||||
|         return False  # avoid prefetching to save memory | ||||
|  | ||||
|  | ||||
| class IndexedCachedDataset(IndexedDataset): | ||||
|  | ||||
|     def __init__(self, path): | ||||
|         super().__init__(path) | ||||
|         self.cache = None | ||||
|         self.cache_index = {} | ||||
|  | ||||
|     @property | ||||
|     def supports_prefetch(self): | ||||
|         return True | ||||
|  | ||||
|     def prefetch(self, indices): | ||||
|         if all(i in self.cache_index for i in indices): | ||||
|             return | ||||
|         if not self.data_file: | ||||
|             self.read_data(self.path) | ||||
|         indices = sorted(set(indices)) | ||||
|         total_size = 0 | ||||
|         for i in indices: | ||||
|             total_size += self.data_offsets[i + 1] - self.data_offsets[i] | ||||
|         self.cache = np.empty(total_size, dtype=self.dtype) | ||||
|         ptx = 0 | ||||
|         self.cache_index.clear() | ||||
|         for i in indices: | ||||
|             self.cache_index[i] = ptx | ||||
|             size = self.data_offsets[i + 1] - self.data_offsets[i] | ||||
|             a = self.cache[ptx: ptx + size] | ||||
|             self.data_file.seek(self.data_offsets[i] * self.element_size) | ||||
|             self.data_file.readinto(a) | ||||
|             ptx += size | ||||
|         if self.data_file: | ||||
|             # close and delete data file after prefetch so we can pickle | ||||
|             self.data_file.close() | ||||
|             self.data_file = None | ||||
|  | ||||
|     # @lru_cache(maxsize=8) | ||||
|     def __getitem__(self, idx): | ||||
|         if isinstance(idx, int): | ||||
|             i = idx | ||||
|             self.check_index(i) | ||||
|             tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] | ||||
|             a = np.empty(tensor_size, dtype=self.dtype) | ||||
|             ptx = self.cache_index[i] | ||||
|             np.copyto(a, self.cache[ptx: ptx + a.size]) | ||||
|             return a | ||||
|         elif isinstance(idx, slice): | ||||
|             # Hack just to make this work, can optimize later if necessary | ||||
|             sents = [] | ||||
|             for i in range(*idx.indices(len(self))): | ||||
|                 sents.append(self[i]) | ||||
|             return sents | ||||
|  | ||||
|  | ||||
| class IndexedDatasetBuilder(object): | ||||
|     element_sizes = { | ||||
|         np.uint8: 1, | ||||
|         np.int8: 1, | ||||
|         np.int16: 2, | ||||
|         np.int32: 4, | ||||
|         np.int64: 8, | ||||
|         np.float32: 4,    # was np.float, which no longer exists in current NumPy | ||||
|         np.double: 8 | ||||
|     } | ||||
|  | ||||
|     def __init__(self, out_file, dtype=np.int32): | ||||
|         self.out_file = open(out_file, 'wb') | ||||
|         self.dtype = dtype | ||||
|         self.data_offsets = [0] | ||||
|         self.dim_offsets = [0] | ||||
|         self.sizes = [] | ||||
|         self.element_size = self.element_sizes[self.dtype] | ||||
|         self.doc_idx = [0] | ||||
|  | ||||
|     def add_item(self, tensor): | ||||
|         nbytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype)) | ||||
|         # integer division keeps the offsets integral (they are stored as int64 in the index) | ||||
|         self.data_offsets.append(self.data_offsets[-1] + nbytes // self.element_size) | ||||
|         for s in tensor.size(): | ||||
|             self.sizes.append(s) | ||||
|         self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) | ||||
|  | ||||
|     def end_document(self): | ||||
|         self.doc_idx.append(len(self.sizes)) | ||||
|  | ||||
|     def merge_file_(self, another_file): | ||||
|         index = IndexedDataset(another_file) | ||||
|         assert index.dtype == self.dtype | ||||
|  | ||||
|         begin = self.data_offsets[-1] | ||||
|         for offset in index.data_offsets[1:]: | ||||
|             self.data_offsets.append(begin + offset) | ||||
|         self.sizes.extend(index.sizes) | ||||
|         begin = self.dim_offsets[-1] | ||||
|         for dim_offset in index.dim_offsets[1:]: | ||||
|             self.dim_offsets.append(begin + dim_offset) | ||||
|  | ||||
|         with open(data_file_path(another_file), 'rb') as f: | ||||
|             while True: | ||||
|                 data = f.read(1024) | ||||
|                 if data: | ||||
|                     self.out_file.write(data) | ||||
|                 else: | ||||
|                     break | ||||
|  | ||||
|     def finalize(self, index_file): | ||||
|         self.out_file.close() | ||||
|         index = open(index_file, 'wb') | ||||
|         index.write(b'TNTIDX\x00\x00') | ||||
|         index.write(struct.pack('<Q', 1)) | ||||
|         index.write(struct.pack('<QQ', code(self.dtype), self.element_size)) | ||||
|         index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes))) | ||||
|         index.write(struct.pack('<Q', len(self.doc_idx))) | ||||
|         write_longs(index, self.dim_offsets) | ||||
|         write_longs(index, self.data_offsets) | ||||
|         write_longs(index, self.sizes) | ||||
|         write_longs(index, self.doc_idx) | ||||
|         index.close() | ||||
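|  | ||||
| # Minimal usage sketch for the non-mmap builder (illustrative; 'corpus' is a hypothetical prefix | ||||
| # and `documents` a hypothetical list of documents, each a list of token-id lists): | ||||
| # | ||||
| #   builder = IndexedDatasetBuilder(data_file_path('corpus'), dtype=np.int32) | ||||
| #   for document in documents: | ||||
| #       for sentence in document: | ||||
| #           builder.add_item(torch.IntTensor(sentence)) | ||||
| #       builder.end_document() | ||||
| #   builder.finalize(index_file_path('corpus')) | ||||
| #   ds = IndexedDataset('corpus')    # ds[i] then returns one sentence as a numpy array | ||||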
|  | ||||
|  | ||||
| def _warmup_mmap_file(path): | ||||
|     with open(path, 'rb') as stream: | ||||
|         while stream.read(100 * 1024 * 1024): | ||||
|             pass | ||||
|  | ||||
|  | ||||
| class MMapIndexedDataset(torch.utils.data.Dataset): | ||||
|     class Index(object): | ||||
|         _HDR_MAGIC = b'MMIDIDX\x00\x00' | ||||
|  | ||||
|         @classmethod | ||||
|         def writer(cls, path, dtype): | ||||
|             class _Writer(object): | ||||
|                 def __enter__(self): | ||||
|                     self._file = open(path, 'wb') | ||||
|  | ||||
|                     self._file.write(cls._HDR_MAGIC) | ||||
|                     self._file.write(struct.pack('<Q', 1)) | ||||
|                     self._file.write(struct.pack('<B', code(dtype))) | ||||
|  | ||||
|                     return self | ||||
|  | ||||
|                 @staticmethod | ||||
|                 def _get_pointers(sizes): | ||||
|                     dtype_size = dtype().itemsize | ||||
|                     address = 0 | ||||
|                     pointers = [] | ||||
|  | ||||
|                     for size in sizes: | ||||
|                         pointers.append(address) | ||||
|                         address += size * dtype_size | ||||
|  | ||||
|                     return pointers | ||||
|  | ||||
|                 def write(self, sizes, doc_idx): | ||||
|                     pointers = self._get_pointers(sizes) | ||||
|  | ||||
|                     self._file.write(struct.pack('<Q', len(sizes))) | ||||
|                     self._file.write(struct.pack('<Q', len(doc_idx))) | ||||
|  | ||||
|                     sizes = np.array(sizes, dtype=np.int32) | ||||
|                     self._file.write(sizes.tobytes(order='C')) | ||||
|                     del sizes | ||||
|  | ||||
|                     pointers = np.array(pointers, dtype=np.int64) | ||||
|                     self._file.write(pointers.tobytes(order='C')) | ||||
|                     del pointers | ||||
|  | ||||
|                     doc_idx = np.array(doc_idx, dtype=np.int64) | ||||
|                     self._file.write(doc_idx.tobytes(order='C')) | ||||
|  | ||||
|                 def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|                     self._file.close() | ||||
|  | ||||
|             return _Writer() | ||||
|  | ||||
|         def __init__(self, path, skip_warmup=False): | ||||
|             with open(path, 'rb') as stream: | ||||
|                 magic_test = stream.read(9) | ||||
|                 assert self._HDR_MAGIC == magic_test, ( | ||||
|                     'Index file doesn\'t match expected format. ' | ||||
|                     'Make sure that --dataset-impl is configured properly.' | ||||
|                 ) | ||||
|                 version = struct.unpack('<Q', stream.read(8)) | ||||
|                 assert (1,) == version | ||||
|  | ||||
|                 dtype_code, = struct.unpack('<B', stream.read(1)) | ||||
|                 self._dtype = dtypes[dtype_code] | ||||
|                 self._dtype_size = self._dtype().itemsize | ||||
|  | ||||
|                 self._len = struct.unpack('<Q', stream.read(8))[0] | ||||
|                 self._doc_count = struct.unpack('<Q', stream.read(8))[0] | ||||
|                 offset = stream.tell() | ||||
|  | ||||
|             if not skip_warmup: | ||||
|                 print("    warming up index mmap file...") | ||||
|                 _warmup_mmap_file(path) | ||||
|  | ||||
|             self._bin_buffer_mmap = np.memmap(path, mode='r', order='C') | ||||
|             self._bin_buffer = memoryview(self._bin_buffer_mmap) | ||||
|             print("    reading sizes...") | ||||
|             self._sizes = np.frombuffer( | ||||
|                 self._bin_buffer, | ||||
|                 dtype=np.int32, | ||||
|                 count=self._len, | ||||
|                 offset=offset) | ||||
|             print("    reading pointers...") | ||||
|             self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, | ||||
|                                            offset=offset + self._sizes.nbytes) | ||||
|             print("    reading document index...") | ||||
|             self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count, | ||||
|                                           offset=offset + self._sizes.nbytes + self._pointers.nbytes) | ||||
|  | ||||
|         def __del__(self): | ||||
|             self._bin_buffer_mmap._mmap.close() | ||||
|             del self._bin_buffer_mmap | ||||
|  | ||||
|         @property | ||||
|         def dtype(self): | ||||
|             return self._dtype | ||||
|  | ||||
|         @property | ||||
|         def sizes(self): | ||||
|             return self._sizes | ||||
|  | ||||
|         @property | ||||
|         def doc_idx(self): | ||||
|             return self._doc_idx | ||||
|  | ||||
|         @lru_cache(maxsize=8) | ||||
|         def __getitem__(self, i): | ||||
|             return self._pointers[i], self._sizes[i] | ||||
|  | ||||
|         def __len__(self): | ||||
|             return self._len | ||||
|  | ||||
|     def __init__(self, path, skip_warmup=False): | ||||
|         super().__init__() | ||||
|  | ||||
|         self._path = None | ||||
|         self._index = None | ||||
|         self._bin_buffer = None | ||||
|  | ||||
|         self._do_init(path, skip_warmup) | ||||
|  | ||||
|     def __getstate__(self): | ||||
|         return self._path | ||||
|  | ||||
|     def __setstate__(self, state): | ||||
|         self._do_init(state) | ||||
|  | ||||
|     def _do_init(self, path, skip_warmup): | ||||
|         self._path = path | ||||
|         self._index = self.Index(index_file_path(self._path), skip_warmup) | ||||
|  | ||||
|         if not skip_warmup: | ||||
|             print("    warming up data mmap file...") | ||||
|             _warmup_mmap_file(data_file_path(self._path)) | ||||
|         print("    creating numpy buffer of mmap...") | ||||
|         self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C') | ||||
|         print("    creating memory view of numpy buffer...") | ||||
|         self._bin_buffer = memoryview(self._bin_buffer_mmap) | ||||
|  | ||||
|     def __del__(self): | ||||
|         self._bin_buffer_mmap._mmap.close() | ||||
|         del self._bin_buffer_mmap | ||||
|         del self._index | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._index) | ||||
|  | ||||
|     # @lru_cache(maxsize=8) | ||||
|     def __getitem__(self, idx): | ||||
|         if isinstance(idx, int): | ||||
|             ptr, size = self._index[idx] | ||||
|             np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, | ||||
|                                      count=size, offset=ptr) | ||||
|             return np_array | ||||
|         elif isinstance(idx, slice): | ||||
|             start, stop, step = idx.indices(len(self)) | ||||
|             if step != 1: | ||||
|                 raise ValueError("Slices into indexed_dataset must be contiguous") | ||||
|             ptr = self._index._pointers[start] | ||||
|             sizes = self._index._sizes[idx] | ||||
|             offsets = list(accumulate(sizes)) | ||||
|             total_size = sum(sizes) | ||||
|             np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, | ||||
|                                      count=total_size, offset=ptr) | ||||
|             sents = np.split(np_array, offsets[:-1]) | ||||
|             return sents | ||||
|  | ||||
|     def get(self, idx, offset=0, length=None): | ||||
|         """ Retrieves a single item from the dataset with the option to only | ||||
|         return a portion of the item. | ||||
|  | ||||
|         get(idx) is the same as [idx] but get() does not support slicing. | ||||
|         """ | ||||
|         ptr, size = self._index[idx] | ||||
|         if length is None: | ||||
|             length = size - offset | ||||
|         ptr += offset * np.dtype(self._index.dtype).itemsize | ||||
|         np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, | ||||
|                                  count=length, offset=ptr) | ||||
|         return np_array | ||||
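|  | ||||
|     # Illustrative example (assuming item 0 stores the token ids [5, 6, 7, 8, 9]): | ||||
|     #   ds.get(0)                     -> array([5, 6, 7, 8, 9]) | ||||
|     #   ds.get(0, offset=2)           -> array([7, 8, 9]) | ||||
|     #   ds.get(0, offset=1, length=3) -> array([6, 7, 8]) | ||||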
|  | ||||
|     @property | ||||
|     def sizes(self): | ||||
|         return self._index.sizes | ||||
|  | ||||
|     @property | ||||
|     def doc_idx(self): | ||||
|         return self._index.doc_idx | ||||
|  | ||||
|     def get_doc_idx(self): | ||||
|         return self._index._doc_idx | ||||
|  | ||||
|     def set_doc_idx(self, doc_idx_): | ||||
|         self._index._doc_idx = doc_idx_ | ||||
|  | ||||
|     @property | ||||
|     def supports_prefetch(self): | ||||
|         return False | ||||
|  | ||||
|     @staticmethod | ||||
|     def exists(path): | ||||
|         return ( | ||||
|             os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class MMapIndexedDatasetBuilder(object): | ||||
|     def __init__(self, out_file, dtype=np.int64): | ||||
|         self._data_file = open(out_file, 'wb') | ||||
|         self._dtype = dtype | ||||
|         self._sizes = [] | ||||
|         self._doc_idx = [0] | ||||
|  | ||||
|     def add_item(self, tensor): | ||||
|         np_array = np.array(tensor.numpy(), dtype=self._dtype) | ||||
|         self._data_file.write(np_array.tobytes(order='C')) | ||||
|         self._sizes.append(np_array.size) | ||||
|  | ||||
|     def end_document(self): | ||||
|         self._doc_idx.append(len(self._sizes)) | ||||
|  | ||||
|     def merge_file_(self, another_file): | ||||
|         # Concatenate index | ||||
|         index = MMapIndexedDataset.Index(index_file_path(another_file)) | ||||
|         assert index.dtype == self._dtype | ||||
|  | ||||
|         for size in index.sizes: | ||||
|             self._sizes.append(size) | ||||
|  | ||||
|         # Concatenate data | ||||
|         with open(data_file_path(another_file), 'rb') as f: | ||||
|             shutil.copyfileobj(f, self._data_file) | ||||
|  | ||||
|     def finalize(self, index_file): | ||||
|         self._data_file.close() | ||||
|  | ||||
|         with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index: | ||||
|             index.write(self._sizes, self._doc_idx) | ||||
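|  | ||||
|  | ||||
| # Minimal usage sketch for the mmap variant (illustrative; file names are hypothetical): | ||||
| # | ||||
| #   builder = make_builder('corpus.bin', impl='mmap', vocab_size=30000)   # vocab < 65500 -> uint16 storage | ||||
| #   builder.add_item(torch.IntTensor([101, 7, 8, 102])) | ||||
| #   builder.end_document() | ||||
| #   builder.finalize('corpus.idx') | ||||
| #   ds = make_dataset('corpus', impl='mmap')   # reads corpus.idx / corpus.bin via np.memmap | ||||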
| @@ -0,0 +1,125 @@ | ||||
| # This file isn't really a formal automated test, it's just a place to | ||||
| # put some code used during development and manual testing of | ||||
| # indexed_dataset. | ||||
|  | ||||
| import argparse | ||||
| import os | ||||
| import sys | ||||
|  | ||||
| import torch | ||||
|  | ||||
| # extend sys.path before importing megatron so the package can be resolved | ||||
| script_dir = os.path.dirname(os.path.realpath(__file__)) | ||||
| sys.path.append(os.path.join(script_dir, "../../../")) | ||||
|  | ||||
| from megatron.data import indexed_dataset | ||||
| from megatron.tokenizer import build_tokenizer | ||||
|  | ||||
|  | ||||
| def test_indexed_dataset(args): | ||||
|     ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) | ||||
|     tokenizer = build_tokenizer(args) | ||||
|     print(len(ds.doc_idx)) | ||||
|     print(len(ds)) | ||||
|     print(ds.doc_idx[-1]) | ||||
|     if ds.supports_prefetch: | ||||
|         # just prefetch the whole thing in test (so assume it is small) | ||||
|         ds.prefetch(range(len(ds))) | ||||
|     if args.count > len(ds.doc_idx) - 1: | ||||
|         args.count = len(ds.doc_idx) - 1 | ||||
|  | ||||
|     for i in range(args.count): | ||||
|         start = ds.doc_idx[i] | ||||
|         end = ds.doc_idx[i + 1] | ||||
|         ids = ds[start:end] | ||||
|         print(f"Document {i}:") | ||||
|         print("--------------") | ||||
|         for s in ids: | ||||
|             assert len(s) > 0 | ||||
|             token_ids = s.data.tolist() | ||||
|             text = tokenizer.detokenize(token_ids) | ||||
|             print(text) | ||||
|             print("---") | ||||
|  | ||||
|  | ||||
| def test_indexed_dataset_get(args): | ||||
|     ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) | ||||
|     tokenizer = build_tokenizer(args) | ||||
|     size = ds.sizes[0] | ||||
|     print(f"size: {size}") | ||||
|     full = ds.get(0) | ||||
|     print(full) | ||||
|     # print(tokenizer.detokenize(full.data.tolist())) | ||||
|     print("---") | ||||
|     end = ds.get(0, offset=size - 10) | ||||
|     print(end) | ||||
|     # print(tokenizer.detokenize(end.data.tolist())) | ||||
|  | ||||
|     start = ds.get(0, length=10) | ||||
|     print(start) | ||||
|     # print(tokenizer.detokenize(start.data.tolist())) | ||||
|  | ||||
|     part = ds.get(0, offset=2, length=8) | ||||
|     print(part) | ||||
|     # print(tokenizer.detokenize(part.data.tolist())) | ||||
|  | ||||
| # def test_albert_dataset(args): | ||||
| #     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) | ||||
| #     # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) | ||||
| #     # ds = AlbertDataset(idataset, tokenizer) | ||||
| #     ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, | ||||
| #                                   args.epochs, args.max_num_samples, | ||||
| #                                   args.masked_lm_prob, args.seq_length, | ||||
| #                                   args.short_seq_prob, args.seed) | ||||
| #     truncated = 0 | ||||
| #     total = 0 | ||||
| #     for i, s in enumerate(ds): | ||||
| #         ids = s['text'] | ||||
| #         tokens = ds.tokenizer.convert_ids_to_tokens(ids) | ||||
| #         print(tokens) | ||||
| #         if i >= args.count-1: | ||||
| #             exit() | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--data', type=str, help='prefix to data files') | ||||
|     parser.add_argument('--dataset-impl', type=str, default='infer', | ||||
|                         choices=['lazy', 'cached', 'mmap', 'infer']) | ||||
|     parser.add_argument('--count', type=int, default=10, | ||||
|                         help='Number of samples/documents to print') | ||||
|  | ||||
|     group = parser.add_argument_group(title='tokenizer') | ||||
|     group.add_argument('--tokenizer-type', type=str, required=True, | ||||
|                        choices=['BertWordPieceLowerCase', | ||||
|                                 'GPT2BPETokenizer'], | ||||
|                        help='What type of tokenizer to use.') | ||||
|     group.add_argument('--vocab-file', type=str, default=None, | ||||
|                        help='Path to the vocab file') | ||||
|     group.add_argument('--merge-file', type=str, default=None, | ||||
|                        help='Path to the BPE merge file (if necessary).') | ||||
|  | ||||
|     parser.add_argument('--epochs', type=int, default=5, | ||||
|                         help='Number of epochs to plan for') | ||||
|     parser.add_argument('--max-num-samples', type=int, default=None, | ||||
|                         help='Maximum number of samples to plan for') | ||||
|     parser.add_argument('--masked-lm-prob', type=float, default=0.15, | ||||
|                         help='probability of masking tokens') | ||||
|     parser.add_argument('--seq-length', type=int, default=512, | ||||
|                         help='maximum sequence length') | ||||
|     parser.add_argument('--short-seq-prob', type=float, default=0.1, | ||||
|                         help='probability of creating a short sequence') | ||||
|     parser.add_argument('--seed', type=int, default=1234, | ||||
|                         help='random seed') | ||||
|     args = parser.parse_args() | ||||
|     args.rank = 0 | ||||
|     args.make_vocab_size_divisible_by = 128 | ||||
|     args.tensor_model_parallel_size = 1 | ||||
|  | ||||
|     if args.dataset_impl == "infer": | ||||
|         args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) | ||||
|  | ||||
| #    test_albert_dataset(args) | ||||
|     test_indexed_dataset_get(args) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
| @@ -0,0 +1,10 @@ | ||||
| #!/bin/bash | ||||
|  | ||||
| IMPL=cached | ||||
| python ../preprocess_data.py \ | ||||
|        --input test_samples.json \ | ||||
|        --vocab vocab.txt \ | ||||
|        --dataset-impl ${IMPL} \ | ||||
|        --output-prefix test_samples_${IMPL} \ | ||||
|        --workers 1 \ | ||||
|        --log-interval 2 | ||||
| @@ -0,0 +1,38 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
|  | ||||
| from .tokenizer import build_tokenizer | ||||
|  | ||||
|  | ||||
| _TOKENIZER = None | ||||
| _PADDED_VOCAB_SIZE = -1 | ||||
|  | ||||
|  | ||||
| def initialize_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): | ||||
|     tokenizer, padded_vocab_size = build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids) | ||||
|     global _TOKENIZER, _PADDED_VOCAB_SIZE | ||||
|     _TOKENIZER = tokenizer | ||||
|     _PADDED_VOCAB_SIZE = padded_vocab_size | ||||
|  | ||||
|  | ||||
| def get_tokenizer(): | ||||
|     global _TOKENIZER | ||||
|     return _TOKENIZER | ||||
|  | ||||
|  | ||||
| def get_padded_vocab_size(): | ||||
|     global _PADDED_VOCAB_SIZE | ||||
|     return _PADDED_VOCAB_SIZE | ||||
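|  | ||||
|  | ||||
| # Typical usage in this example (the vocab path is hypothetical): | ||||
| # | ||||
| #   initialize_tokenizer('bert-large-uncased-vocab.txt', 'BertWordPieceLowerCase') | ||||
| #   tokenizer = get_tokenizer()            # module-level singleton set by initialize_tokenizer | ||||
| #   vocab_size = get_padded_vocab_size()   # vocab size padded for tensor parallelism | ||||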
| @@ -0,0 +1,431 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright 2018 The Google AI Language Team Authors. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
| """Tokenization classes.""" | ||||
|  | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
|  | ||||
| import collections | ||||
| import re | ||||
| import unicodedata | ||||
| import six | ||||
|  | ||||
|  | ||||
| def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): | ||||
|     """Checks whether the casing config is consistent with the checkpoint name.""" | ||||
|  | ||||
|     # The casing has to be passed in by the user and there is no explicit check | ||||
|     # as to whether it matches the checkpoint. The casing information probably | ||||
|     # should have been stored in the bert_config.json file, but it's not, so | ||||
|     # we have to heuristically detect it to validate. | ||||
|  | ||||
|     if not init_checkpoint: | ||||
|         return | ||||
|  | ||||
|     m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) | ||||
|     if m is None: | ||||
|         return | ||||
|  | ||||
|     model_name = m.group(1) | ||||
|  | ||||
|     lower_models = [ | ||||
|         "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", | ||||
|         "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" | ||||
|     ] | ||||
|  | ||||
|     cased_models = [ | ||||
|         "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", | ||||
|         "multi_cased_L-12_H-768_A-12" | ||||
|     ] | ||||
|  | ||||
|     is_bad_config = False | ||||
|     if model_name in lower_models and not do_lower_case: | ||||
|         is_bad_config = True | ||||
|         actual_flag = "False" | ||||
|         case_name = "lowercased" | ||||
|         opposite_flag = "True" | ||||
|  | ||||
|     if model_name in cased_models and do_lower_case: | ||||
|         is_bad_config = True | ||||
|         actual_flag = "True" | ||||
|         case_name = "cased" | ||||
|         opposite_flag = "False" | ||||
|  | ||||
|     if is_bad_config: | ||||
|         raise ValueError( | ||||
|             "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " | ||||
|             "However, `%s` seems to be a %s model, so you " | ||||
|             "should pass in `--do_lower_case=%s` so that the fine-tuning matches " | ||||
|             "how the model was pre-trained. If this error is wrong, please " | ||||
|             "just comment out this check." % (actual_flag, init_checkpoint, | ||||
|                                               model_name, case_name, opposite_flag)) | ||||
|  | ||||
|  | ||||
| def convert_to_unicode(text): | ||||
|     """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" | ||||
|     if six.PY3: | ||||
|         if isinstance(text, str): | ||||
|             return text | ||||
|         elif isinstance(text, bytes): | ||||
|             return text.decode("utf-8", "ignore") | ||||
|         else: | ||||
|             raise ValueError("Unsupported string type: %s" % (type(text))) | ||||
|     elif six.PY2: | ||||
|         if isinstance(text, str): | ||||
|             return text.decode("utf-8", "ignore") | ||||
|         elif isinstance(text, unicode): | ||||
|             return text | ||||
|         else: | ||||
|             raise ValueError("Unsupported string type: %s" % (type(text))) | ||||
|     else: | ||||
|         raise ValueError("Not running on Python2 or Python 3?") | ||||
|  | ||||
|  | ||||
| def printable_text(text): | ||||
|     """Returns text encoded in a way suitable for print or `tf.logging`.""" | ||||
|  | ||||
|     # These functions want `str` for both Python2 and Python3, but in one case | ||||
|     # it's a Unicode string and in the other it's a byte string. | ||||
|     if six.PY3: | ||||
|         if isinstance(text, str): | ||||
|             return text | ||||
|         elif isinstance(text, bytes): | ||||
|             return text.decode("utf-8", "ignore") | ||||
|         else: | ||||
|             raise ValueError("Unsupported string type: %s" % (type(text))) | ||||
|     elif six.PY2: | ||||
|         if isinstance(text, str): | ||||
|             return text | ||||
|         elif isinstance(text, unicode): | ||||
|             return text.encode("utf-8") | ||||
|         else: | ||||
|             raise ValueError("Unsupported string type: %s" % (type(text))) | ||||
|     else: | ||||
|         raise ValueError("Not running on Python2 or Python 3?") | ||||
|  | ||||
|  | ||||
| def load_vocab(vocab_file): | ||||
|     """Loads a vocabulary file into a dictionary.""" | ||||
|     vocab = collections.OrderedDict() | ||||
|     index = 0 | ||||
|     with open(vocab_file, "r") as reader: | ||||
|         while True: | ||||
|             token = convert_to_unicode(reader.readline()) | ||||
|             if not token: | ||||
|                 break | ||||
|             token = token.strip() | ||||
|             vocab[token] = index | ||||
|             index += 1 | ||||
|     return vocab | ||||
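|  | ||||
| # Example (illustrative): a vocab file whose first three lines are "[PAD]", "[UNK]" and "the" | ||||
| # loads as OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('the', 2)]), i.e. line number == token id. | ||||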
|  | ||||
|  | ||||
| def convert_by_vocab(vocab, items): | ||||
|     """Converts a sequence of [tokens|ids] using the vocab.""" | ||||
|     output = [] | ||||
|     for item in items: | ||||
|         output.append(vocab[item]) | ||||
|     return output | ||||
|  | ||||
|  | ||||
| def convert_tokens_to_ids(vocab, tokens): | ||||
|     return convert_by_vocab(vocab, tokens) | ||||
|  | ||||
|  | ||||
| def convert_ids_to_tokens(inv_vocab, ids): | ||||
|     return convert_by_vocab(inv_vocab, ids) | ||||
|  | ||||
|  | ||||
| def whitespace_tokenize(text): | ||||
|     """Runs basic whitespace cleaning and splitting on a piece of text.""" | ||||
|     text = text.strip() | ||||
|     if not text: | ||||
|         return [] | ||||
|     tokens = text.split() | ||||
|     return tokens | ||||
|  | ||||
|  | ||||
| class FullTokenizer(object): | ||||
|     """Runs end-to-end tokenization.""" | ||||
|  | ||||
|     def __init__(self, vocab_file, do_lower_case=True): | ||||
|         self.vocab = load_vocab(vocab_file) | ||||
|         self.inv_vocab = {v: k for k, v in self.vocab.items()} | ||||
|         self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) | ||||
|         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | ||||
|  | ||||
|     def tokenize(self, text): | ||||
|         split_tokens = [] | ||||
|         for token in self.basic_tokenizer.tokenize(text): | ||||
|             for sub_token in self.wordpiece_tokenizer.tokenize(token): | ||||
|                 split_tokens.append(sub_token) | ||||
|  | ||||
|         return split_tokens | ||||
|  | ||||
|     def convert_tokens_to_ids(self, tokens): | ||||
|         return convert_by_vocab(self.vocab, tokens) | ||||
|  | ||||
|     def convert_ids_to_tokens(self, ids): | ||||
|         return convert_by_vocab(self.inv_vocab, ids) | ||||
|  | ||||
|     @staticmethod | ||||
|     def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): | ||||
|         """ Converts a sequence of tokens (string) in a single string. """ | ||||
|  | ||||
|         def clean_up_tokenization(out_string): | ||||
|             """ Clean up a list of simple English tokenization artifacts | ||||
|             like spaces before punctuation and abbreviated forms. | ||||
|             """ | ||||
|             out_string = ( | ||||
|                 out_string.replace(" .", ".") | ||||
|                     .replace(" ?", "?") | ||||
|                     .replace(" !", "!") | ||||
|                     .replace(" ,", ",") | ||||
|                     .replace(" ' ", "'") | ||||
|                     .replace(" n't", "n't") | ||||
|                     .replace(" 'm", "'m") | ||||
|                     .replace(" 's", "'s") | ||||
|                     .replace(" 've", "'ve") | ||||
|                     .replace(" 're", "'re") | ||||
|             ) | ||||
|             return out_string | ||||
|  | ||||
|         text = ' '.join(tokens).replace(' ##', '').strip() | ||||
|         if clean_up_tokenization_spaces: | ||||
|             clean_text = clean_up_tokenization(text) | ||||
|             return clean_text | ||||
|         else: | ||||
|             return text | ||||
|  | ||||
|     def vocab_size(self): | ||||
|         return len(self.vocab) | ||||
|  | ||||
|  | ||||
| class BasicTokenizer(object): | ||||
|     """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | ||||
|  | ||||
|     def __init__(self, do_lower_case=True): | ||||
|         """Constructs a BasicTokenizer. | ||||
|  | ||||
|         Args: | ||||
|           do_lower_case: Whether to lower case the input. | ||||
|         """ | ||||
|         self.do_lower_case = do_lower_case | ||||
|  | ||||
|     def tokenize(self, text): | ||||
|         """Tokenizes a piece of text.""" | ||||
|         text = convert_to_unicode(text) | ||||
|         text = self._clean_text(text) | ||||
|  | ||||
|         # This was added on November 1st, 2018 for the multilingual and Chinese | ||||
|         # models. This is also applied to the English models now, but it doesn't | ||||
|         # matter since the English models were not trained on any Chinese data | ||||
|         # and generally don't have any Chinese data in them (there are Chinese | ||||
|         # characters in the vocabulary because Wikipedia does have some Chinese | ||||
|         # words in the English Wikipedia.). | ||||
|         text = self._tokenize_chinese_chars(text) | ||||
|  | ||||
|         orig_tokens = whitespace_tokenize(text) | ||||
|         split_tokens = [] | ||||
|         for token in orig_tokens: | ||||
|             if self.do_lower_case: | ||||
|                 token = token.lower() | ||||
|                 token = self._run_strip_accents(token) | ||||
|             split_tokens.extend(self._run_split_on_punc(token)) | ||||
|  | ||||
|         output_tokens = whitespace_tokenize(" ".join(split_tokens)) | ||||
|         return output_tokens | ||||
|  | ||||
|     def _run_strip_accents(self, text): | ||||
|         """Strips accents from a piece of text.""" | ||||
|         text = unicodedata.normalize("NFD", text) | ||||
|         output = [] | ||||
|         for char in text: | ||||
|             cat = unicodedata.category(char) | ||||
|             if cat == "Mn": | ||||
|                 continue | ||||
|             output.append(char) | ||||
|         return "".join(output) | ||||
|  | ||||
|     def _run_split_on_punc(self, text): | ||||
|         """Splits punctuation on a piece of text.""" | ||||
|         chars = list(text) | ||||
|         i = 0 | ||||
|         start_new_word = True | ||||
|         output = [] | ||||
|         while i < len(chars): | ||||
|             char = chars[i] | ||||
|             if _is_punctuation(char): | ||||
|                 output.append([char]) | ||||
|                 start_new_word = True | ||||
|             else: | ||||
|                 if start_new_word: | ||||
|                     output.append([]) | ||||
|                 start_new_word = False | ||||
|                 output[-1].append(char) | ||||
|             i += 1 | ||||
|  | ||||
|         return ["".join(x) for x in output] | ||||
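|  | ||||
|     # Example (illustrative): _run_split_on_punc("hello,world!") -> ["hello", ",", "world", "!"] | ||||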
|  | ||||
|     def _tokenize_chinese_chars(self, text): | ||||
|         """Adds whitespace around any CJK character.""" | ||||
|         output = [] | ||||
|         for char in text: | ||||
|             cp = ord(char) | ||||
|             if self._is_chinese_char(cp): | ||||
|                 output.append(" ") | ||||
|                 output.append(char) | ||||
|                 output.append(" ") | ||||
|             else: | ||||
|                 output.append(char) | ||||
|         return "".join(output) | ||||
|  | ||||
|     def _is_chinese_char(self, cp): | ||||
|         """Checks whether CP is the codepoint of a CJK character.""" | ||||
|         # This defines a "chinese character" as anything in the CJK Unicode block: | ||||
|         #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | ||||
|         # | ||||
|         # Note that the CJK Unicode block is NOT all Japanese and Korean characters, | ||||
|         # despite its name. The modern Korean Hangul alphabet is a different block, | ||||
|         # as is Japanese Hiragana and Katakana. Those alphabets are used to write | ||||
|         # space-separated words, so they are not treated specially and handled | ||||
|         # like all of the other languages. | ||||
|         if ((cp >= 0x4E00 and cp <= 0x9FFF) or  # | ||||
|             (cp >= 0x3400 and cp <= 0x4DBF) or  # | ||||
|             (cp >= 0x20000 and cp <= 0x2A6DF) or  # | ||||
|             (cp >= 0x2A700 and cp <= 0x2B73F) or  # | ||||
|             (cp >= 0x2B740 and cp <= 0x2B81F) or  # | ||||
|             (cp >= 0x2B820 and cp <= 0x2CEAF) or | ||||
|             (cp >= 0xF900 and cp <= 0xFAFF) or  # | ||||
|                 (cp >= 0x2F800 and cp <= 0x2FA1F)):  # | ||||
|             return True | ||||
|  | ||||
|         return False | ||||
|  | ||||
|     def _clean_text(self, text): | ||||
|         """Performs invalid character removal and whitespace cleanup on text.""" | ||||
|         output = [] | ||||
|         for char in text: | ||||
|             cp = ord(char) | ||||
|             if cp == 0 or cp == 0xfffd or _is_control(char): | ||||
|                 continue | ||||
|             if _is_whitespace(char): | ||||
|                 output.append(" ") | ||||
|             else: | ||||
|                 output.append(char) | ||||
|         return "".join(output) | ||||
|  | ||||
|  | ||||
| class WordpieceTokenizer(object): | ||||
|     """Runs WordPiece tokenization.""" | ||||
|  | ||||
|     def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): | ||||
|         self.vocab = vocab | ||||
|         self.unk_token = unk_token | ||||
|         self.max_input_chars_per_word = max_input_chars_per_word | ||||
|  | ||||
|     def tokenize(self, text): | ||||
|         """Tokenizes a piece of text into its word pieces. | ||||
|  | ||||
|         This uses a greedy longest-match-first algorithm to perform tokenization | ||||
|         using the given vocabulary. | ||||
|  | ||||
|         For example: | ||||
|           input = "unaffable" | ||||
|           output = ["un", "##aff", "##able"] | ||||
|  | ||||
|         Args: | ||||
|           text: A single token or whitespace separated tokens. This should have | ||||
|             already been passed through `BasicTokenizer`. | ||||
|  | ||||
|         Returns: | ||||
|           A list of wordpiece tokens. | ||||
|         """ | ||||
|  | ||||
|         text = convert_to_unicode(text) | ||||
|  | ||||
|         output_tokens = [] | ||||
|         for token in whitespace_tokenize(text): | ||||
|             chars = list(token) | ||||
|             if len(chars) > self.max_input_chars_per_word: | ||||
|                 output_tokens.append(self.unk_token) | ||||
|                 continue | ||||
|  | ||||
|             is_bad = False | ||||
|             start = 0 | ||||
|             sub_tokens = [] | ||||
|             while start < len(chars): | ||||
|                 end = len(chars) | ||||
|                 cur_substr = None | ||||
|                 while start < end: | ||||
|                     substr = "".join(chars[start:end]) | ||||
|                     if start > 0: | ||||
|                         substr = "##" + substr | ||||
|                     if substr in self.vocab: | ||||
|                         cur_substr = substr | ||||
|                         break | ||||
|                     end -= 1 | ||||
|                 if cur_substr is None: | ||||
|                     is_bad = True | ||||
|                     break | ||||
|                 sub_tokens.append(cur_substr) | ||||
|                 start = end | ||||
|  | ||||
|             if is_bad: | ||||
|                 output_tokens.append(self.unk_token) | ||||
|             else: | ||||
|                 output_tokens.extend(sub_tokens) | ||||
|         return output_tokens | ||||
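|  | ||||
|     # Greedy longest-match example (illustrative toy vocab): with | ||||
|     #   vocab = {"un": 0, "##aff": 1, "##able": 2, "##a": 3, "[UNK]": 4} | ||||
|     # WordpieceTokenizer(vocab).tokenize("unaffable") returns ["un", "##aff", "##able"], | ||||
|     # preferring "##aff" over the shorter "##a"; a word with no matching piece at some | ||||
|     # position collapses to the single token "[UNK]". | ||||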
|  | ||||
|  | ||||
| def _is_whitespace(char): | ||||
|     """Checks whether `chars` is a whitespace character.""" | ||||
|     # \t, \n, and \r are technically control characters but we treat them | ||||
|     # as whitespace since they are generally considered as such. | ||||
|     if char == " " or char == "\t" or char == "\n" or char == "\r": | ||||
|         return True | ||||
|     cat = unicodedata.category(char) | ||||
|     if cat == "Zs": | ||||
|         return True | ||||
|     return False | ||||
|  | ||||
|  | ||||
| def _is_control(char): | ||||
|     """Checks whether `chars` is a control character.""" | ||||
|     # These are technically control characters but we count them as whitespace | ||||
|     # characters. | ||||
|     if char == "\t" or char == "\n" or char == "\r": | ||||
|         return False | ||||
|     cat = unicodedata.category(char) | ||||
|     if cat in ("Cc", "Cf"): | ||||
|         return True | ||||
|     return False | ||||
|  | ||||
|  | ||||
| def _is_punctuation(char): | ||||
|     """Checks whether `chars` is a punctuation character.""" | ||||
|     cp = ord(char) | ||||
|     # We treat all non-letter/number ASCII as punctuation. | ||||
|     # Characters such as "^", "$", and "`" are not in the Unicode | ||||
|     # Punctuation class but we treat them as punctuation anyways, for | ||||
|     # consistency. | ||||
|     if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or | ||||
|             (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): | ||||
|         return True | ||||
|     cat = unicodedata.category(char) | ||||
|     if cat.startswith("P"): | ||||
|         return True | ||||
|     return False | ||||
							
								
								
									
										256
									examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py
										Normal file
							| @@ -0,0 +1,256 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
| """Megatron tokenizers.""" | ||||
|  | ||||
| from abc import ABC | ||||
| from abc import abstractmethod | ||||
| from colossalai.core import global_context as gpc | ||||
| from colossalai.context import ParallelMode | ||||
|  | ||||
| from .bert_tokenization import FullTokenizer as FullBertTokenizer | ||||
|  | ||||
|  | ||||
| def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): | ||||
|     """Initialize tokenizer.""" | ||||
|     if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: | ||||
|         print('> building {} tokenizer ...'.format(tokenizer_type), | ||||
|               flush=True) | ||||
|  | ||||
|     # Select and instantiate the tokenizer. | ||||
|     if tokenizer_type == 'BertWordPieceLowerCase': | ||||
|         tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, | ||||
|                                             lower_case=True, | ||||
|                                             vocab_extra_ids=vocab_extra_ids) | ||||
|     elif tokenizer_type == 'BertWordPieceCase': | ||||
|         tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, | ||||
|                                             lower_case=False, | ||||
|                                             vocab_extra_ids=vocab_extra_ids) | ||||
|     else: | ||||
|         raise NotImplementedError('{} tokenizer is not ' | ||||
|                                   'implemented.'.format(tokenizer_type)) | ||||
|  | ||||
|     # Add vocab size. | ||||
|     padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size) | ||||
|  | ||||
|     return tokenizer, padded_vocab_size | ||||
|  | ||||
|  | ||||
| def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): | ||||
|     """Pad the vocab size so that it is divisible by the tensor-parallel size | ||||
|     while still keeping a GPU-friendly size.""" | ||||
|  | ||||
|     after = orig_vocab_size | ||||
|  | ||||
|     if gpc.is_initialized(ParallelMode.TENSOR): | ||||
|         multiple = make_vocab_size_divisible_by * gpc.get_world_size(ParallelMode.TENSOR) | ||||
|     else: | ||||
|         multiple = make_vocab_size_divisible_by | ||||
|     while (after % multiple) != 0: | ||||
|         after += 1 | ||||
|     if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: | ||||
|         print(' > padded vocab (size: {}) with {} dummy tokens ' | ||||
|               '(new size: {})'.format( | ||||
|                   orig_vocab_size, after - orig_vocab_size, after), flush=True) | ||||
|     return after | ||||
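|  | ||||
| # Worked example (illustrative): with an original vocab size of 30524, | ||||
| # make_vocab_size_divisible_by = 128 and a tensor-parallel world size of 2, | ||||
| # the multiple is 256 and the padded size becomes 30720 (196 dummy tokens). | ||||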
|  | ||||
|  | ||||
| class AbstractTokenizer(ABC): | ||||
|     """Abstract class for tokenizer.""" | ||||
|  | ||||
|     def __init__(self, name): | ||||
|         self.name = name | ||||
|         super().__init__() | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def vocab_size(self): | ||||
|         pass | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def vocab(self): | ||||
|         """Dictionary from vocab text token to id token.""" | ||||
|         pass | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def inv_vocab(self): | ||||
|         """Dictionary from vocab id token to text token.""" | ||||
|         pass | ||||
|  | ||||
|     @abstractmethod | ||||
|     def tokenize(self, text): | ||||
|         pass | ||||
|  | ||||
|     def detokenize(self, token_ids): | ||||
|         raise NotImplementedError('detokenizer is not implemented for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def cls(self): | ||||
|         raise NotImplementedError('CLS is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def sep(self): | ||||
|         raise NotImplementedError('SEP is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def pad(self): | ||||
|         raise NotImplementedError('PAD is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def eod(self): | ||||
|         raise NotImplementedError('EOD is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def mask(self): | ||||
|         raise NotImplementedError('MASK is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|  | ||||
|  | ||||
| class _BertWordPieceTokenizer(AbstractTokenizer): | ||||
|     """Original BERT wordpiece tokenizer.""" | ||||
|  | ||||
|     def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): | ||||
|         if lower_case: | ||||
|             name = 'BERT Lower Case' | ||||
|         else: | ||||
|             name = 'BERT Upper Case' | ||||
|         super().__init__(name) | ||||
|         self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) | ||||
|         self.cls_id = self.tokenizer.vocab['[CLS]'] | ||||
|         self.sep_id = self.tokenizer.vocab['[SEP]'] | ||||
|         self.pad_id = self.tokenizer.vocab['[PAD]'] | ||||
|         self.mask_id = self.tokenizer.vocab['[MASK]'] | ||||
|         self._additional_special_tokens = [] | ||||
|  | ||||
|         # (dsachan) Add BOS and EOS tokens. | ||||
|         # Note: SPECIAL_TOKENS is kept for reference only; the tokens are | ||||
|         # added explicitly below and the dict itself is never read. | ||||
|         SPECIAL_TOKENS = {'eos_token': '[EOS]', | ||||
|                           'bos_token': '[BOS]'} | ||||
|         self._bos_token = '[BOS]' | ||||
|         self.add_token(self._bos_token) | ||||
|         self._bos_token_id = self.vocab.get(self._bos_token) | ||||
|  | ||||
|         self._eos_token = '[EOS]' | ||||
|         self.add_token(self._eos_token) | ||||
|         self._eos_token_id = self.vocab.get(self._eos_token) | ||||
|  | ||||
|         # (dsachan) Add additional special tokens | ||||
|         # These can be used as sentinel tokens in T5 model inputs | ||||
|         additional_special_tokens = [] | ||||
|         additional_special_tokens.extend( | ||||
|             ["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)]) | ||||
|         self.add_additional_special_tokens(additional_special_tokens) | ||||
|  | ||||
|     def add_token(self, token): | ||||
|         if token not in self.vocab: | ||||
|             self.inv_vocab[self.vocab_size] = token | ||||
|             # self.vocab_size comes from len(vocab) | ||||
|             # and it will increase as we add elements | ||||
|             self.vocab[token] = self.vocab_size | ||||
|  | ||||
|     def add_additional_special_tokens(self, tokens_list): | ||||
|         setattr(self, "additional_special_tokens", tokens_list) | ||||
|         for value in tokens_list: | ||||
|             self.add_token(value) | ||||
|  | ||||
|     @property | ||||
|     def vocab_size(self): | ||||
|         return self.tokenizer.vocab_size() | ||||
|  | ||||
|     @property | ||||
|     def vocab(self): | ||||
|         return self.tokenizer.vocab | ||||
|  | ||||
|     @property | ||||
|     def inv_vocab(self): | ||||
|         return self.tokenizer.inv_vocab | ||||
|  | ||||
|     def tokenize(self, text): | ||||
|         text_tokens = self.tokenizer.tokenize(text) | ||||
|         return self.tokenizer.convert_tokens_to_ids(text_tokens) | ||||
|  | ||||
|     def decode(self, ids): | ||||
|         tokens = self.tokenizer.convert_ids_to_tokens(ids) | ||||
|         return self.tokenizer.convert_tokens_to_string(tokens) | ||||
|  | ||||
|     def decode_token_ids(self, token_ids): | ||||
|         tokens = self.tokenizer.convert_ids_to_tokens(token_ids) | ||||
|         exclude_list = ['[PAD]', '[CLS]'] | ||||
|         non_pads = [t for t in tokens if t not in exclude_list] | ||||
|  | ||||
|         result = "" | ||||
|         for s in non_pads: | ||||
|             if s.startswith("##"): | ||||
|                 result += s[2:] | ||||
|             else: | ||||
|                 result += " " + s | ||||
|  | ||||
|         return result | ||||
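|  | ||||
|     # Hedged illustration of decode_token_ids (hypothetical ids): if the ids | ||||
|     # map to ['[CLS]', 'se', '##quence', 'parallel', '[PAD]'], the result is | ||||
|     # " sequence parallel" -- "##" wordpieces are glued to the previous token, | ||||
|     # '[PAD]'/'[CLS]' entries are dropped, and a leading space remains. | ||||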
|  | ||||
|     @property | ||||
|     def cls(self): | ||||
|         return self.cls_id | ||||
|  | ||||
|     @property | ||||
|     def sep(self): | ||||
|         return self.sep_id | ||||
|  | ||||
|     @property | ||||
|     def pad(self): | ||||
|         return self.pad_id | ||||
|  | ||||
|     @property | ||||
|     def mask(self): | ||||
|         return self.mask_id | ||||
|  | ||||
|     @property | ||||
|     def bos_token(self): | ||||
|         """ Beginning of sentence token id """ | ||||
|         return self._bos_token | ||||
|  | ||||
|     @property | ||||
|     def eos_token(self): | ||||
|         """ End of sentence token id """ | ||||
|         return self._eos_token | ||||
|  | ||||
|     @property | ||||
|     def additional_special_tokens(self): | ||||
|         """ All the additional special tokens you may want to use (list of strings).""" | ||||
|         return self._additional_special_tokens | ||||
|  | ||||
|     @property | ||||
|     def bos_token_id(self): | ||||
|         """ Id of the beginning of sentence token in the vocabulary.""" | ||||
|         return self._bos_token_id | ||||
|  | ||||
|     @property | ||||
|     def eos_token_id(self): | ||||
|         """ Id of the end of sentence token in the vocabulary.""" | ||||
|         return self._eos_token_id | ||||
|  | ||||
|     @property | ||||
|     def additional_special_tokens_ids(self): | ||||
|         """ Ids of all the additional special tokens in the vocabulary (list of integers).""" | ||||
|         return [self.vocab.get(token) for token in self._additional_special_tokens] | ||||
|  | ||||
|     @additional_special_tokens.setter | ||||
|     def additional_special_tokens(self, value): | ||||
|         self._additional_special_tokens = value | ||||
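|  | ||||
|  | ||||
| # Hedged usage sketch (not part of the original file). Assuming a BERT | ||||
| # wordpiece vocabulary file exists at the hypothetical path 'bert-vocab.txt': | ||||
| # | ||||
| #     tokenizer = _BertWordPieceTokenizer('bert-vocab.txt', lower_case=True) | ||||
| #     ids = tokenizer.tokenize('sequence parallelism for long inputs') | ||||
| #     print(tokenizer.vocab_size, tokenizer.cls, tokenizer.sep) | ||||
| #     print(tokenizer.decode_token_ids(ids)) | ||||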