Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-30 21:39:05 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
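The hunks below are the mechanical result of running the updated pre-commit hooks over the tree: single quotes become double quotes, long argument lists are re-wrapped with trailing commas, and hand-wrapped continuations are collapsed against the configured line width. As a rough illustration only, the sketch below reproduces one of these rewrites with Black's Python API; the 120-character line length is an assumption inferred from the formatted output in this diff, not a setting read from the repository's configuration.

# Rough sketch (not part of the commit): how a Black-style formatter, driven by
# the updated pre-commit hooks, rewrites one of the signatures in the hunk below.
# The line_length=120 value is an assumption; the project's actual setting may differ.
import black

OLD_SOURCE = (
    "def build_train_valid_test_data_iterators(train_iters,\n"
    "                                          global_batch_size,\n"
    "                                          eval_interval,\n"
    "                                          eval_iters,\n"
    "                                          dataloader_type='single',\n"
    "                                          **kwargs):\n"
    "    pass\n"
)

# format_str applies the same rules the commit-time hook enforces:
# quote normalization and re-wrapping of the argument list to the configured width.
NEW_SOURCE = black.format_str(OLD_SOURCE, mode=black.Mode(line_length=120))
print(NEW_SOURCE)

Running `pre-commit run --all-files` applies the same transformation across the whole repository, which is what produces the purely cosmetic hunks that follow.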
		| @@ -15,16 +15,13 @@ def cyclic_iter(iter): | ||||
|             yield x | ||||
|  | ||||
|  | ||||
| def build_train_valid_test_data_iterators(train_iters, | ||||
|                                           global_batch_size, | ||||
|                                           eval_interval, | ||||
|                                           eval_iters, | ||||
|                                           dataloader_type='single', | ||||
|                                           **kwargs): | ||||
| def build_train_valid_test_data_iterators( | ||||
|     train_iters, global_batch_size, eval_interval, eval_iters, dataloader_type="single", **kwargs | ||||
| ): | ||||
|     (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) | ||||
|  | ||||
|     logger = get_dist_logger() | ||||
|     logger.info('> building train, validation, and test datasets ...', ranks=[0]) | ||||
|     logger.info("> building train, validation, and test datasets ...", ranks=[0]) | ||||
|  | ||||
|     # Backward compatibility, assume fixed batch size. | ||||
|     # if iteration > 0 and consumed_train_samples == 0: | ||||
| @@ -38,29 +35,29 @@ def build_train_valid_test_data_iterators(train_iters, | ||||
|  | ||||
|     # Data loader only on rank 0 of each model parallel group. | ||||
|     if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: | ||||
|  | ||||
|         # Number of train/valid/test samples. | ||||
|         train_samples = train_iters * global_batch_size | ||||
|         eval_iters_ = (train_iters // eval_interval + 1) * eval_iters | ||||
|         test_iters = eval_iters | ||||
|         train_val_test_num_samples = [train_samples, eval_iters_ * global_batch_size, test_iters * global_batch_size] | ||||
|         logger.info(' > datasets target sizes (minimum size):') | ||||
|         logger.info('    train:      {}'.format(train_val_test_num_samples[0]), ranks=[0]) | ||||
|         logger.info('    validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) | ||||
|         logger.info('    test:       {}'.format(train_val_test_num_samples[2]), ranks=[0]) | ||||
|         logger.info(" > datasets target sizes (minimum size):") | ||||
|         logger.info("    train:      {}".format(train_val_test_num_samples[0]), ranks=[0]) | ||||
|         logger.info("    validation: {}".format(train_val_test_num_samples[1]), ranks=[0]) | ||||
|         logger.info("    test:       {}".format(train_val_test_num_samples[2]), ranks=[0]) | ||||
|  | ||||
|         # Build the datasets. | ||||
|         train_ds, valid_ds, test_ds = build_train_valid_test_datasets( | ||||
|             train_valid_test_num_samples=train_val_test_num_samples, **kwargs) | ||||
|             train_valid_test_num_samples=train_val_test_num_samples, **kwargs | ||||
|         ) | ||||
|  | ||||
|         # Build dataloaders. | ||||
|         dp_size = gpc.get_world_size(ParallelMode.DATA) | ||||
|         train_dataloader = build_pretraining_data_loader(train_ds, | ||||
|                                                          consumed_samples=0, | ||||
|                                                          micro_batch_size=global_batch_size // dp_size) | ||||
|         valid_dataloader = build_pretraining_data_loader(valid_ds, | ||||
|                                                          consumed_samples=0, | ||||
|                                                          micro_batch_size=global_batch_size // dp_size) | ||||
|         train_dataloader = build_pretraining_data_loader( | ||||
|             train_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size | ||||
|         ) | ||||
|         valid_dataloader = build_pretraining_data_loader( | ||||
|             valid_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size | ||||
|         ) | ||||
|         test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size // dp_size) | ||||
|  | ||||
|         # Flags to know if we need to do training/validation/testing. | ||||
| @@ -73,29 +70,26 @@ def build_train_valid_test_data_iterators(train_iters, | ||||
|         flags = torch.cuda.LongTensor([0, 0, 0]) | ||||
|  | ||||
|     # Broadcast num tokens. | ||||
|     torch.distributed.broadcast(flags, | ||||
|                                 gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], | ||||
|                                 group=gpc.get_group(ParallelMode.TENSOR)) | ||||
|     torch.distributed.broadcast( | ||||
|         flags, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) | ||||
|     ) | ||||
|  | ||||
|     # Build iterators. | ||||
|     dl_type = dataloader_type | ||||
|     assert dl_type in ['single', 'cyclic'] | ||||
|     assert dl_type in ["single", "cyclic"] | ||||
|  | ||||
|     if train_dataloader is not None: | ||||
|         train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ | ||||
|             else iter(cyclic_iter(train_dataloader)) | ||||
|         train_data_iterator = iter(train_dataloader) if dl_type == "single" else iter(cyclic_iter(train_dataloader)) | ||||
|     else: | ||||
|         train_data_iterator = None | ||||
|  | ||||
|     if valid_dataloader is not None: | ||||
|         valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ | ||||
|             else iter(cyclic_iter(valid_dataloader)) | ||||
|         valid_data_iterator = iter(valid_dataloader) if dl_type == "single" else iter(cyclic_iter(valid_dataloader)) | ||||
|     else: | ||||
|         valid_data_iterator = None | ||||
|  | ||||
|     if test_dataloader is not None: | ||||
|         test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ | ||||
|             else iter(cyclic_iter(test_dataloader)) | ||||
|         test_data_iterator = iter(test_dataloader) if dl_type == "single" else iter(cyclic_iter(test_dataloader)) | ||||
|     else: | ||||
|         test_data_iterator = None | ||||
|  | ||||
|   | ||||
| @@ -15,7 +15,7 @@ def _build_key_size_numel_dictionaries(keys, data): | ||||
|     if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: | ||||
|         offset = 0 | ||||
|         for key in keys: | ||||
|             assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' | ||||
|             assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM" | ||||
|             size = data[key].size() | ||||
|             for i, s in enumerate(size): | ||||
|                 sizes[i + offset] = s | ||||
| @@ -23,9 +23,9 @@ def _build_key_size_numel_dictionaries(keys, data): | ||||
|  | ||||
|     # Move to GPU and broadcast. | ||||
|     sizes_cuda = torch.cuda.LongTensor(sizes) | ||||
|     torch.distributed.broadcast(sizes_cuda, | ||||
|                                 gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], | ||||
|                                 group=gpc.get_group(ParallelMode.TENSOR)) | ||||
|     torch.distributed.broadcast( | ||||
|         sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) | ||||
|     ) | ||||
|  | ||||
|     # Move back to cpu and unpack. | ||||
|     sizes_cpu = sizes_cuda.cpu() | ||||
| @@ -73,9 +73,9 @@ def broadcast_data(keys, data, datatype): | ||||
|         flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) | ||||
|  | ||||
|     # Broadcast | ||||
|     torch.distributed.broadcast(flatten_data, | ||||
|                                 gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], | ||||
|                                 group=gpc.get_group(ParallelMode.TENSOR)) | ||||
|     torch.distributed.broadcast( | ||||
|         flatten_data, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) | ||||
|     ) | ||||
|  | ||||
|     # Unpack | ||||
|     output = {} | ||||
| @@ -93,7 +93,7 @@ def get_batch(data_iterator): | ||||
|     """Build the batch.""" | ||||
|  | ||||
|     # Items and their type. | ||||
|     keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] | ||||
|     keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] | ||||
|     datatype = torch.int64 | ||||
|  | ||||
|     # Broadcast data. | ||||
| @@ -104,12 +104,12 @@ def get_batch(data_iterator): | ||||
|     data_b = broadcast_data(keys, data, datatype) | ||||
|  | ||||
|     # Unpack. | ||||
|     tokens = data_b['text'].long() | ||||
|     types = data_b['types'].long() | ||||
|     sentence_order = data_b['is_random'].long() | ||||
|     loss_mask = data_b['loss_mask'].float() | ||||
|     lm_labels = data_b['labels'].long() | ||||
|     padding_mask = data_b['padding_mask'].long() | ||||
|     tokens = data_b["text"].long() | ||||
|     types = data_b["types"].long() | ||||
|     sentence_order = data_b["is_random"].long() | ||||
|     loss_mask = data_b["loss_mask"].float() | ||||
|     lm_labels = data_b["labels"].long() | ||||
|     padding_mask = data_b["padding_mask"].long() | ||||
|  | ||||
|     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask | ||||
|  | ||||
| @@ -118,7 +118,7 @@ def get_batch_for_sequence_parallel(data_iterator): | ||||
|     """Build the batch.""" | ||||
|  | ||||
|     # Items and their type. | ||||
|     keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] | ||||
|     keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] | ||||
|     datatype = torch.int64 | ||||
|  | ||||
|     # Broadcast data. | ||||
| @@ -134,24 +134,23 @@ def get_batch_for_sequence_parallel(data_iterator): | ||||
|     global_rank = torch.distributed.get_rank() | ||||
|     local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) | ||||
|     local_rank = global_rank % local_world_size | ||||
|     seq_length = data_b['text'].size(1) | ||||
|     seq_length = data_b["text"].size(1) | ||||
|     sub_seq_length = seq_length // local_world_size | ||||
|     sub_seq_start = local_rank * sub_seq_length | ||||
|     sub_seq_end = (local_rank + 1) * sub_seq_length | ||||
|     # | ||||
|     # # Unpack. | ||||
|     tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() | ||||
|     types = data_b['types'][:, sub_seq_start:sub_seq_end].long() | ||||
|     sentence_order = data_b['is_random'].long() | ||||
|     loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float() | ||||
|     lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long() | ||||
|     padding_mask = data_b['padding_mask'].long() | ||||
|     tokens = data_b["text"][:, sub_seq_start:sub_seq_end].long() | ||||
|     types = data_b["types"][:, sub_seq_start:sub_seq_end].long() | ||||
|     sentence_order = data_b["is_random"].long() | ||||
|     loss_mask = data_b["loss_mask"][:, sub_seq_start:sub_seq_end].float() | ||||
|     lm_labels = data_b["labels"][:, sub_seq_start:sub_seq_end].long() | ||||
|     padding_mask = data_b["padding_mask"].long() | ||||
|  | ||||
|     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask | ||||
|  | ||||
|  | ||||
| class SequenceParallelDataIterator: | ||||
|  | ||||
|     def __init__(self, data_iter): | ||||
|         self.data_iter = data_iter | ||||
|  | ||||
|   | ||||
| @@ -41,10 +41,19 @@ except: | ||||
|  | ||||
|  | ||||
| class BertDataset(Dataset): | ||||
|  | ||||
|     def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length, | ||||
|                  short_seq_prob, seed, binary_head): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         name, | ||||
|         indexed_dataset, | ||||
|         data_prefix, | ||||
|         num_epochs, | ||||
|         max_num_samples, | ||||
|         masked_lm_prob, | ||||
|         max_seq_length, | ||||
|         short_seq_prob, | ||||
|         seed, | ||||
|         binary_head, | ||||
|     ): | ||||
|         # Params to store. | ||||
|         self.name = name | ||||
|         self.seed = seed | ||||
| @@ -61,11 +70,12 @@ class BertDataset(Dataset): | ||||
|             data_prefix, | ||||
|             num_epochs, | ||||
|             max_num_samples, | ||||
|             self.max_seq_length - 3,    # account for added tokens, | ||||
|             self.max_seq_length - 3,  # account for added tokens, | ||||
|             short_seq_prob, | ||||
|             self.seed, | ||||
|             self.name, | ||||
|             self.binary_head) | ||||
|             self.binary_head, | ||||
|         ) | ||||
|  | ||||
|         # Vocab stuff. | ||||
|         tokenizer = get_tokenizer() | ||||
| @@ -89,7 +99,7 @@ class BertDataset(Dataset): | ||||
|         return build_training_sample( | ||||
|             sample, | ||||
|             seq_length, | ||||
|             self.max_seq_length,    # needed for padding | ||||
|             self.max_seq_length,  # needed for padding | ||||
|             self.vocab_id_list, | ||||
|             self.vocab_id_to_token_dict, | ||||
|             self.cls_id, | ||||
| @@ -98,37 +108,39 @@ class BertDataset(Dataset): | ||||
|             self.pad_id, | ||||
|             self.masked_lm_prob, | ||||
|             np_rng, | ||||
|             self.binary_head) | ||||
|             self.binary_head, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, | ||||
|                          seed, name, binary_head): | ||||
| def get_samples_mapping_( | ||||
|     indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head | ||||
| ): | ||||
|     logger = get_dist_logger() | ||||
|     if not num_epochs: | ||||
|         if not max_num_samples: | ||||
|             raise ValueError("Need to specify either max_num_samples " | ||||
|                              "or num_epochs") | ||||
|             raise ValueError("Need to specify either max_num_samples " "or num_epochs") | ||||
|         num_epochs = np.iinfo(np.int32).max - 1 | ||||
|     if not max_num_samples: | ||||
|         max_num_samples = np.iinfo(np.int64).max - 1 | ||||
|  | ||||
|     # Filename of the index mapping | ||||
|     indexmap_filename = data_prefix | ||||
|     indexmap_filename += '_{}_indexmap'.format(name) | ||||
|     indexmap_filename += "_{}_indexmap".format(name) | ||||
|     if num_epochs != (np.iinfo(np.int32).max - 1): | ||||
|         indexmap_filename += '_{}ep'.format(num_epochs) | ||||
|         indexmap_filename += "_{}ep".format(num_epochs) | ||||
|     if max_num_samples != (np.iinfo(np.int64).max - 1): | ||||
|         indexmap_filename += '_{}mns'.format(max_num_samples) | ||||
|     indexmap_filename += '_{}msl'.format(max_seq_length) | ||||
|     indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) | ||||
|     indexmap_filename += '_{}s'.format(seed) | ||||
|     indexmap_filename += '.npy' | ||||
|         indexmap_filename += "_{}mns".format(max_num_samples) | ||||
|     indexmap_filename += "_{}msl".format(max_seq_length) | ||||
|     indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob) | ||||
|     indexmap_filename += "_{}s".format(seed) | ||||
|     indexmap_filename += ".npy" | ||||
|  | ||||
|     # Build the indexed mapping if not exist. | ||||
|     if torch.distributed.get_rank() == 0 and \ | ||||
|        not os.path.isfile(indexmap_filename): | ||||
|         print(' > WARNING: could not find index map file {}, building ' | ||||
|               'the indices on rank 0 ...'.format(indexmap_filename)) | ||||
|     if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename): | ||||
|         print( | ||||
|             " > WARNING: could not find index map file {}, building " | ||||
|             "the indices on rank 0 ...".format(indexmap_filename) | ||||
|         ) | ||||
|  | ||||
|         # Make sure the types match the helpers input types. | ||||
|         assert indexed_dataset.doc_idx.dtype == np.int64 | ||||
| @@ -137,18 +149,27 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl | ||||
|         # Build samples mapping | ||||
|         verbose = torch.distributed.get_rank() == 0 | ||||
|         start_time = time.time() | ||||
|         logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0]) | ||||
|         logger.info("\n > building samples index mapping for {} ...".format(name), ranks=[0]) | ||||
|         # First compile and then import. | ||||
|         samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs, | ||||
|                                                 max_num_samples, max_seq_length, short_seq_prob, seed, verbose, | ||||
|                                                 2 if binary_head else 1) | ||||
|         logger.info('\n > done building samples index maping', ranks=[0]) | ||||
|         samples_mapping = helpers.build_mapping( | ||||
|             indexed_dataset.doc_idx, | ||||
|             indexed_dataset.sizes, | ||||
|             num_epochs, | ||||
|             max_num_samples, | ||||
|             max_seq_length, | ||||
|             short_seq_prob, | ||||
|             seed, | ||||
|             verbose, | ||||
|             2 if binary_head else 1, | ||||
|         ) | ||||
|         logger.info("\n > done building samples index maping", ranks=[0]) | ||||
|         np.save(indexmap_filename, samples_mapping, allow_pickle=True) | ||||
|         logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0]) | ||||
|         logger.info("\n > saved the index mapping in {}".format(indexmap_filename), ranks=[0]) | ||||
|         # Make sure all the ranks have built the mapping | ||||
|         logger.info('\n > elapsed time to build and save samples mapping ' | ||||
|                     '(seconds): {:4f}'.format(time.time() - start_time), | ||||
|                     ranks=[0]) | ||||
|         logger.info( | ||||
|             "\n > elapsed time to build and save samples mapping " "(seconds): {:4f}".format(time.time() - start_time), | ||||
|             ranks=[0], | ||||
|         ) | ||||
|     # This should be a barrier but nccl barrier assumes | ||||
|     # device_index=rank which is not the case for model | ||||
|     # parallel case | ||||
| @@ -156,22 +177,38 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl | ||||
|     torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA)) | ||||
|     if gpc.is_initialized(ParallelMode.PIPELINE): | ||||
|         torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE)) | ||||
|     assert counts[0].item() == (torch.distributed.get_world_size() // | ||||
|                                 torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))) | ||||
|     assert counts[0].item() == ( | ||||
|         torch.distributed.get_world_size() | ||||
|         // torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE)) | ||||
|     ) | ||||
|  | ||||
|     # Load indexed dataset. | ||||
|     start_time = time.time() | ||||
|     samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') | ||||
|     logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) + | ||||
|                 '\n    loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) + | ||||
|                 '\n    total number of samples: {}'.format(samples_mapping.shape[0]), | ||||
|                 ranks=[0]) | ||||
|     samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r") | ||||
|     logger.info( | ||||
|         "\n > loading indexed mapping from {}".format(indexmap_filename) | ||||
|         + "\n    loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) | ||||
|         + "\n    total number of samples: {}".format(samples_mapping.shape[0]), | ||||
|         ranks=[0], | ||||
|     ) | ||||
|  | ||||
|     return samples_mapping | ||||
|  | ||||
|  | ||||
| def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id, | ||||
|                           sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head): | ||||
| def build_training_sample( | ||||
|     sample, | ||||
|     target_seq_length, | ||||
|     max_seq_length, | ||||
|     vocab_id_list, | ||||
|     vocab_id_to_token_dict, | ||||
|     cls_id, | ||||
|     sep_id, | ||||
|     mask_id, | ||||
|     pad_id, | ||||
|     masked_lm_prob, | ||||
|     np_rng, | ||||
|     binary_head, | ||||
| ): | ||||
|     """Build training sample. | ||||
|  | ||||
|     Arguments: | ||||
| @@ -215,22 +252,30 @@ def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_li | ||||
|  | ||||
|     # Masking. | ||||
|     max_predictions_per_seq = masked_lm_prob * max_num_tokens | ||||
|     (tokens, masked_positions, masked_labels, | ||||
|      _) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id, | ||||
|                                        mask_id, max_predictions_per_seq, np_rng) | ||||
|     (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions( | ||||
|         tokens, | ||||
|         vocab_id_list, | ||||
|         vocab_id_to_token_dict, | ||||
|         masked_lm_prob, | ||||
|         cls_id, | ||||
|         sep_id, | ||||
|         mask_id, | ||||
|         max_predictions_per_seq, | ||||
|         np_rng, | ||||
|     ) | ||||
|  | ||||
|     # Padding. | ||||
|     tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ | ||||
|         = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, | ||||
|                                    masked_labels, pad_id, max_seq_length) | ||||
|     tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy( | ||||
|         tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length | ||||
|     ) | ||||
|  | ||||
|     train_sample = { | ||||
|         'text': tokens_np, | ||||
|         'types': tokentypes_np, | ||||
|         'labels': labels_np, | ||||
|         'is_random': int(is_next_random), | ||||
|         'loss_mask': loss_mask_np, | ||||
|         'padding_mask': padding_mask_np, | ||||
|         'truncated': int(truncated) | ||||
|         "text": tokens_np, | ||||
|         "types": tokentypes_np, | ||||
|         "labels": labels_np, | ||||
|         "is_random": int(is_next_random), | ||||
|         "loss_mask": loss_mask_np, | ||||
|         "padding_mask": padding_mask_np, | ||||
|         "truncated": int(truncated), | ||||
|     } | ||||
|     return train_sample | ||||
|   | ||||
| @@ -22,9 +22,7 @@ import torch | ||||
|  | ||||
|  | ||||
| class BlendableDataset(torch.utils.data.Dataset): | ||||
|  | ||||
|     def __init__(self, datasets, weights): | ||||
|  | ||||
|         self.datasets = datasets | ||||
|         num_datasets = len(datasets) | ||||
|         assert num_datasets == len(weights) | ||||
| @@ -46,12 +44,16 @@ class BlendableDataset(torch.utils.data.Dataset): | ||||
|         self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) | ||||
|  | ||||
|         from . import helpers | ||||
|         helpers.build_blending_indices(self.dataset_index, | ||||
|                                        self.dataset_sample_index, | ||||
|                                        weights, num_datasets, self.size, | ||||
|                                        torch.distributed.get_rank() == 0) | ||||
|         print('> elapsed time for building blendable dataset indices: ' | ||||
|               '{:.2f} (sec)'.format(time.time() - start_time)) | ||||
|  | ||||
|         helpers.build_blending_indices( | ||||
|             self.dataset_index, | ||||
|             self.dataset_sample_index, | ||||
|             weights, | ||||
|             num_datasets, | ||||
|             self.size, | ||||
|             torch.distributed.get_rank() == 0, | ||||
|         ) | ||||
|         print("> elapsed time for building blendable dataset indices: " "{:.2f} (sec)".format(time.time() - start_time)) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.size | ||||
|   | ||||
| @@ -1,29 +1,34 @@ | ||||
| from .blendable_dataset import BlendableDataset | ||||
| from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ | ||||
| from .bert_dataset import BertDataset | ||||
| from colossalai.logging import get_dist_logger | ||||
|  | ||||
| DSET_TYPE_BERT = 'standard_bert' | ||||
| DSET_TYPE_ICT = 'ict' | ||||
| DSET_TYPE_T5 = 't5' | ||||
| from .bert_dataset import BertDataset | ||||
| from .blendable_dataset import BlendableDataset | ||||
| from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ | ||||
|  | ||||
| DSET_TYPE_BERT = "standard_bert" | ||||
| DSET_TYPE_ICT = "ict" | ||||
| DSET_TYPE_T5 = "t5" | ||||
|  | ||||
| DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] | ||||
|  | ||||
|  | ||||
| def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                                      train_valid_test_num_samples, | ||||
|                                      max_seq_length, masked_lm_prob, | ||||
|                                      short_seq_prob, seed, skip_warmup, | ||||
|                                      binary_head, | ||||
|                                      dataset_type='standard_bert'): | ||||
|  | ||||
| def _build_train_valid_test_datasets( | ||||
|     data_prefix, | ||||
|     data_impl, | ||||
|     splits_string, | ||||
|     train_valid_test_num_samples, | ||||
|     max_seq_length, | ||||
|     masked_lm_prob, | ||||
|     short_seq_prob, | ||||
|     seed, | ||||
|     skip_warmup, | ||||
|     binary_head, | ||||
|     dataset_type="standard_bert", | ||||
| ): | ||||
|     if dataset_type not in DSET_TYPES: | ||||
|         raise ValueError("Invalid dataset_type: ", dataset_type) | ||||
|  | ||||
|     # Indexed dataset. | ||||
|     indexed_dataset = get_indexed_dataset_(data_prefix, | ||||
|                                            data_impl, | ||||
|                                            skip_warmup) | ||||
|     indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) | ||||
|  | ||||
|     # Get start and end indices of train/valid/train into doc-idx | ||||
|     # Note that doc-idx is designed to be num-docs + 1 so we can | ||||
| @@ -34,22 +39,25 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|     logger = get_dist_logger() | ||||
|  | ||||
|     # Print stats about the splits. | ||||
|     logger.info('\n > dataset split:', ranks=[0]) | ||||
|     logger.info("\n > dataset split:", ranks=[0]) | ||||
|  | ||||
|     def print_split_stats(name, index): | ||||
|         start_index = indexed_dataset.doc_idx[splits[index]] | ||||
|         end_index = indexed_dataset.doc_idx[splits[index + 1]] | ||||
|         logger.info('\n    {}:'.format(name) + | ||||
|                     '\n     document indices in [{}, {}) total of {} documents'.format( | ||||
|                         splits[index], splits[index + 1], | ||||
|                         splits[index + 1] - splits[index]) + | ||||
|                     '\n     sentence indices in [{}, {}) total of {} sentences'.format( | ||||
|                         start_index, end_index, | ||||
|                         end_index - start_index), | ||||
|                     ranks=[0]) | ||||
|     print_split_stats('train', 0) | ||||
|     print_split_stats('validation', 1) | ||||
|     print_split_stats('test', 2) | ||||
|         logger.info( | ||||
|             "\n    {}:".format(name) | ||||
|             + "\n     document indices in [{}, {}) total of {} documents".format( | ||||
|                 splits[index], splits[index + 1], splits[index + 1] - splits[index] | ||||
|             ) | ||||
|             + "\n     sentence indices in [{}, {}) total of {} sentences".format( | ||||
|                 start_index, end_index, end_index - start_index | ||||
|             ), | ||||
|             ranks=[0], | ||||
|         ) | ||||
|  | ||||
|     print_split_stats("train", 0) | ||||
|     print_split_stats("validation", 1) | ||||
|     print_split_stats("test", 2) | ||||
|  | ||||
|     def build_dataset(index, name): | ||||
|         dataset = None | ||||
| @@ -80,44 +88,53 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                     masked_lm_prob=masked_lm_prob, | ||||
|                     short_seq_prob=short_seq_prob, | ||||
|                     binary_head=binary_head, | ||||
|                     **kwargs | ||||
|                     **kwargs, | ||||
|                 ) | ||||
|  | ||||
|             # Set the original pointer so dataset remains the main dataset. | ||||
|             indexed_dataset.set_doc_idx(doc_idx_ptr) | ||||
|             # Checks. | ||||
|             assert indexed_dataset.doc_idx[0] == 0 | ||||
|             assert indexed_dataset.doc_idx.shape[0] == \ | ||||
|                 (total_num_of_documents + 1) | ||||
|             assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) | ||||
|         return dataset | ||||
|  | ||||
|     train_dataset = build_dataset(0, 'train') | ||||
|     valid_dataset = build_dataset(1, 'valid') | ||||
|     test_dataset = build_dataset(2, 'test') | ||||
|     train_dataset = build_dataset(0, "train") | ||||
|     valid_dataset = build_dataset(1, "valid") | ||||
|     test_dataset = build_dataset(2, "test") | ||||
|  | ||||
|     return (train_dataset, valid_dataset, test_dataset) | ||||
|  | ||||
|  | ||||
| def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                                     train_valid_test_num_samples, | ||||
|                                     max_seq_length, masked_lm_prob, | ||||
|                                     short_seq_prob, seed, skip_warmup, | ||||
|                                     binary_head, | ||||
|                                     dataset_type='standard_bert'): | ||||
|  | ||||
| def build_train_valid_test_datasets( | ||||
|     data_prefix, | ||||
|     data_impl, | ||||
|     splits_string, | ||||
|     train_valid_test_num_samples, | ||||
|     max_seq_length, | ||||
|     masked_lm_prob, | ||||
|     short_seq_prob, | ||||
|     seed, | ||||
|     skip_warmup, | ||||
|     binary_head, | ||||
|     dataset_type="standard_bert", | ||||
| ): | ||||
|     if len(data_prefix) == 1: | ||||
|         return _build_train_valid_test_datasets(data_prefix[0], | ||||
|                                                 data_impl, splits_string, | ||||
|                                                 train_valid_test_num_samples, | ||||
|                                                 max_seq_length, masked_lm_prob, | ||||
|                                                 short_seq_prob, seed, | ||||
|                                                 skip_warmup, | ||||
|                                                 binary_head, | ||||
|                                                 dataset_type=dataset_type) | ||||
|         return _build_train_valid_test_datasets( | ||||
|             data_prefix[0], | ||||
|             data_impl, | ||||
|             splits_string, | ||||
|             train_valid_test_num_samples, | ||||
|             max_seq_length, | ||||
|             masked_lm_prob, | ||||
|             short_seq_prob, | ||||
|             seed, | ||||
|             skip_warmup, | ||||
|             binary_head, | ||||
|             dataset_type=dataset_type, | ||||
|         ) | ||||
|     # Blending dataset. | ||||
|     # Parse the values. | ||||
|     output = get_datasets_weights_and_num_samples(data_prefix, | ||||
|                                                   train_valid_test_num_samples) | ||||
|     output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) | ||||
|     prefixes, weights, datasets_train_valid_test_num_samples = output | ||||
|  | ||||
|     # Build individual datasets. | ||||
| @@ -126,10 +143,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|     test_datasets = [] | ||||
|     for i in range(len(prefixes)): | ||||
|         train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( | ||||
|             prefixes[i], data_impl, splits_string, | ||||
|             prefixes[i], | ||||
|             data_impl, | ||||
|             splits_string, | ||||
|             datasets_train_valid_test_num_samples[i], | ||||
|             max_seq_length, masked_lm_prob, short_seq_prob, | ||||
|             seed, skip_warmup, binary_head, dataset_type=dataset_type) | ||||
|             max_seq_length, | ||||
|             masked_lm_prob, | ||||
|             short_seq_prob, | ||||
|             seed, | ||||
|             skip_warmup, | ||||
|             binary_head, | ||||
|             dataset_type=dataset_type, | ||||
|         ) | ||||
|         if train_ds: | ||||
|             train_datasets.append(train_ds) | ||||
|         if valid_ds: | ||||
| @@ -148,5 +173,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|     if test_datasets: | ||||
|         blending_test_dataset = BlendableDataset(test_datasets, weights) | ||||
|  | ||||
|     return (blending_train_dataset, blending_valid_dataset, | ||||
|             blending_test_dataset) | ||||
|     return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) | ||||
|   | ||||
| @@ -14,7 +14,6 @@ | ||||
| # limitations under the License. | ||||
| """Dataloaders.""" | ||||
|  | ||||
| import random | ||||
|  | ||||
| import torch | ||||
|  | ||||
| @@ -22,61 +21,60 @@ from colossalai.legacy.context import ParallelMode | ||||
| from colossalai.legacy.core import global_context as gpc | ||||
|  | ||||
|  | ||||
| def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): | ||||
| def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type="single", num_workers=0): | ||||
|     """Build dataloader given an input dataset.""" | ||||
|  | ||||
|     if dataset is None: | ||||
|         return None | ||||
|  | ||||
|     # Megatron sampler | ||||
|     if dataloader_type == 'single': | ||||
|         batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset), | ||||
|                                                    consumed_samples=consumed_samples, | ||||
|                                                    micro_batch_size=micro_batch_size, | ||||
|                                                    data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), | ||||
|                                                    data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) | ||||
|     elif dataloader_type == 'cyclic': | ||||
|         batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset), | ||||
|                                                          consumed_samples=consumed_samples, | ||||
|                                                          micro_batch_size=micro_batch_size, | ||||
|                                                          data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), | ||||
|                                                          data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) | ||||
|     if dataloader_type == "single": | ||||
|         batch_sampler = MegatronPretrainingSampler( | ||||
|             total_samples=len(dataset), | ||||
|             consumed_samples=consumed_samples, | ||||
|             micro_batch_size=micro_batch_size, | ||||
|             data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), | ||||
|             data_parallel_size=gpc.get_world_size(ParallelMode.DATA), | ||||
|         ) | ||||
|     elif dataloader_type == "cyclic": | ||||
|         batch_sampler = MegatronPretrainingRandomSampler( | ||||
|             total_samples=len(dataset), | ||||
|             consumed_samples=consumed_samples, | ||||
|             micro_batch_size=micro_batch_size, | ||||
|             data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), | ||||
|             data_parallel_size=gpc.get_world_size(ParallelMode.DATA), | ||||
|         ) | ||||
|     else: | ||||
|         raise Exception('{} dataloader type is not supported.'.format(dataloader_type)) | ||||
|         raise Exception("{} dataloader type is not supported.".format(dataloader_type)) | ||||
|  | ||||
|     # Torch dataloader. | ||||
|     return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True) | ||||
|  | ||||
|  | ||||
| class MegatronPretrainingSampler: | ||||
|  | ||||
|     def __init__(self, | ||||
|                  total_samples, | ||||
|                  consumed_samples, | ||||
|                  micro_batch_size, | ||||
|                  data_parallel_rank, | ||||
|                  data_parallel_size, | ||||
|                  drop_last=True): | ||||
|     def __init__( | ||||
|         self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True | ||||
|     ): | ||||
|         # Keep a copy of input params for later use. | ||||
|         self.total_samples = total_samples | ||||
|         self.consumed_samples = consumed_samples | ||||
|         self.micro_batch_size = micro_batch_size | ||||
|         self.data_parallel_rank = data_parallel_rank | ||||
|         self.micro_batch_times_data_parallel_size = \ | ||||
|             self.micro_batch_size * data_parallel_size | ||||
|         self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size | ||||
|         self.drop_last = drop_last | ||||
|  | ||||
|         # Sanity checks. | ||||
|         assert self.total_samples > 0, \ | ||||
|             'no sample to consume: {}'.format(self.total_samples) | ||||
|         assert self.consumed_samples < self.total_samples, \ | ||||
|             'no samples left to consume: {}, {}'.format(self.consumed_samples, | ||||
|                                                         self.total_samples) | ||||
|         assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) | ||||
|         assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format( | ||||
|             self.consumed_samples, self.total_samples | ||||
|         ) | ||||
|         assert self.micro_batch_size > 0 | ||||
|         assert data_parallel_size > 0 | ||||
|         assert self.data_parallel_rank < data_parallel_size, \ | ||||
|             'data_parallel_rank should be smaller than data size: {}, ' \ | ||||
|             '{}'.format(self.data_parallel_rank, data_parallel_size) | ||||
|         assert ( | ||||
|             self.data_parallel_rank < data_parallel_size | ||||
|         ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( | ||||
|             self.data_parallel_rank, data_parallel_size | ||||
|         ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.total_samples | ||||
| @@ -103,7 +101,6 @@ class MegatronPretrainingSampler: | ||||
|  | ||||
|  | ||||
| class MegatronPretrainingRandomSampler: | ||||
|  | ||||
|     def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): | ||||
|         # Keep a copy of input params for later use. | ||||
|         self.total_samples = total_samples | ||||
| @@ -111,19 +108,18 @@ class MegatronPretrainingRandomSampler: | ||||
|         self.micro_batch_size = micro_batch_size | ||||
|         self.data_parallel_rank = data_parallel_rank | ||||
|         self.data_parallel_size = data_parallel_size | ||||
|         self.micro_batch_times_data_parallel_size = \ | ||||
|             self.micro_batch_size * data_parallel_size | ||||
|         self.last_batch_size = \ | ||||
|             self.total_samples % self.micro_batch_times_data_parallel_size | ||||
|         self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size | ||||
|         self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size | ||||
|  | ||||
|         # Sanity checks. | ||||
|         assert self.total_samples > 0, \ | ||||
|             'no sample to consume: {}'.format(self.total_samples) | ||||
|         assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) | ||||
|         assert self.micro_batch_size > 0 | ||||
|         assert data_parallel_size > 0 | ||||
|         assert self.data_parallel_rank < data_parallel_size, \ | ||||
|             'data_parallel_rank should be smaller than data size: {}, ' \ | ||||
|             '{}'.format(self.data_parallel_rank, data_parallel_size) | ||||
|         assert ( | ||||
|             self.data_parallel_rank < data_parallel_size | ||||
|         ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( | ||||
|             self.data_parallel_rank, data_parallel_size | ||||
|         ) | ||||
|  | ||||
|     def __len__(self): | ||||
|         return self.total_samples | ||||
| @@ -135,8 +131,7 @@ class MegatronPretrainingRandomSampler: | ||||
|         assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 | ||||
|  | ||||
|         # data sharding and random sampling | ||||
|         bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ | ||||
|             * self.micro_batch_size | ||||
|         bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size | ||||
|         bucket_offset = current_epoch_samples // self.data_parallel_size | ||||
|         start_idx = self.data_parallel_rank * bucket_size | ||||
|  | ||||
|   | ||||
| @@ -18,32 +18,33 @@ | ||||
| #   https://github.com/google-research/albert/blob/master/create_pretraining_data.py | ||||
| # with some modifications. | ||||
|  | ||||
| import collections | ||||
| import math | ||||
| import time | ||||
| import collections | ||||
| from colossalai.logging import get_dist_logger | ||||
|  | ||||
| import numpy as np | ||||
|  | ||||
| from colossalai.logging import get_dist_logger | ||||
|  | ||||
| from .blendable_dataset import BlendableDataset | ||||
| from .indexed_dataset import make_dataset as make_indexed_dataset | ||||
|  | ||||
| DSET_TYPE_STD = 'standard_bert' | ||||
| DSET_TYPE_ICT = 'ict' | ||||
| DSET_TYPE_STD = "standard_bert" | ||||
| DSET_TYPE_ICT = "ict" | ||||
|  | ||||
| DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] | ||||
|  | ||||
|  | ||||
| def get_datasets_weights_and_num_samples(data_prefix, | ||||
|                                          train_valid_test_num_samples): | ||||
|  | ||||
| def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples): | ||||
|     # The data prefix should be in the format of: | ||||
|     #   weight-1, data-prefix-1, weight-2, data-prefix-2, .. | ||||
|     assert len(data_prefix) % 2 == 0 | ||||
|     num_datasets = len(data_prefix) // 2 | ||||
|     weights = [0]*num_datasets | ||||
|     prefixes = [0]*num_datasets | ||||
|     weights = [0] * num_datasets | ||||
|     prefixes = [0] * num_datasets | ||||
|     for i in range(num_datasets): | ||||
|         weights[i] = float(data_prefix[2*i]) | ||||
|         prefixes[i] = (data_prefix[2*i+1]).strip() | ||||
|         weights[i] = float(data_prefix[2 * i]) | ||||
|         prefixes[i] = (data_prefix[2 * i + 1]).strip() | ||||
|     # Normalize weights | ||||
|     weight_sum = 0.0 | ||||
|     for weight in weights: | ||||
| @@ -57,8 +58,8 @@ def get_datasets_weights_and_num_samples(data_prefix, | ||||
|     datasets_train_valid_test_num_samples = [] | ||||
|     for weight in weights: | ||||
|         datasets_train_valid_test_num_samples.append( | ||||
|             [int(math.ceil(val * weight * 1.005)) | ||||
|              for val in train_valid_test_num_samples]) | ||||
|             [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples] | ||||
|         ) | ||||
|  | ||||
|     return prefixes, weights, datasets_train_valid_test_num_samples | ||||
|  | ||||
| @@ -68,11 +69,13 @@ def compile_helper(): | ||||
|     is invoked on a single process.""" | ||||
|     import os | ||||
|     import subprocess | ||||
|  | ||||
|     path = os.path.abspath(os.path.dirname(__file__)) | ||||
|     ret = subprocess.run(['make', '-C', path]) | ||||
|     ret = subprocess.run(["make", "-C", path]) | ||||
|     if ret.returncode != 0: | ||||
|         print("Making C++ dataset helpers module failed, exiting.") | ||||
|         import sys | ||||
|  | ||||
|         sys.exit(1) | ||||
|  | ||||
|  | ||||
| @@ -82,7 +85,7 @@ def get_a_and_b_segments(sample, np_rng): | ||||
|     # Number of sentences in the sample. | ||||
|     n_sentences = len(sample) | ||||
|     # Make sure we always have two sentences. | ||||
|     assert n_sentences > 1, 'make sure each sample has at least two sentences.' | ||||
|     assert n_sentences > 1, "make sure each sample has at least two sentences." | ||||
|  | ||||
|     # First part: | ||||
|     # `a_end` is how many sentences go into the `A`. | ||||
| @@ -110,7 +113,7 @@ def get_a_and_b_segments(sample, np_rng): | ||||
|  | ||||
| def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): | ||||
|     """Truncates a pair of sequences to a maximum sequence length.""" | ||||
|     #print(len_a, len_b, max_num_tokens) | ||||
|     # print(len_a, len_b, max_num_tokens) | ||||
|     assert len_a > 0 | ||||
|     if len_a + len_b <= max_num_tokens: | ||||
|         return False | ||||
| @@ -155,8 +158,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): | ||||
|     return tokens, tokentypes | ||||
|  | ||||
|  | ||||
| MaskedLmInstance = collections.namedtuple("MaskedLmInstance", | ||||
|                                           ["index", "label"]) | ||||
| MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) | ||||
|  | ||||
|  | ||||
| def is_start_piece(piece): | ||||
| @@ -168,16 +170,21 @@ def is_start_piece(piece): | ||||
|     return not piece.startswith("##") | ||||
|  | ||||
|  | ||||
| def create_masked_lm_predictions(tokens, | ||||
|                                  vocab_id_list, vocab_id_to_token_dict, | ||||
|                                  masked_lm_prob, | ||||
|                                  cls_id, sep_id, mask_id, | ||||
|                                  max_predictions_per_seq, | ||||
|                                  np_rng, | ||||
|                                  max_ngrams=3, | ||||
|                                  do_whole_word_mask=True, | ||||
|                                  favor_longer_ngram=False, | ||||
|                                  do_permutation=False): | ||||
| def create_masked_lm_predictions( | ||||
|     tokens, | ||||
|     vocab_id_list, | ||||
|     vocab_id_to_token_dict, | ||||
|     masked_lm_prob, | ||||
|     cls_id, | ||||
|     sep_id, | ||||
|     mask_id, | ||||
|     max_predictions_per_seq, | ||||
|     np_rng, | ||||
|     max_ngrams=3, | ||||
|     do_whole_word_mask=True, | ||||
|     favor_longer_ngram=False, | ||||
|     do_permutation=False, | ||||
| ): | ||||
|     """Creates the predictions for the masked LM objective. | ||||
|     Note: Tokens here are vocab ids and not text tokens.""" | ||||
|  | ||||
| @@ -187,7 +194,7 @@ def create_masked_lm_predictions(tokens, | ||||
|     # on-the-fly whole word masking is possible. | ||||
|     token_boundary = [0] * len(tokens) | ||||
|  | ||||
|     for (i, token) in enumerate(tokens): | ||||
|     for i, token in enumerate(tokens): | ||||
|         if token == cls_id or token == sep_id: | ||||
|             token_boundary[i] = 1 | ||||
|             continue | ||||
| @@ -197,8 +204,7 @@ def create_masked_lm_predictions(tokens, | ||||
|         # Note that Whole Word Masking does *not* change the training code | ||||
|         # at all -- we still predict each WordPiece independently, softmaxed | ||||
|         # over the entire vocabulary. | ||||
|         if (do_whole_word_mask and len(cand_indexes) >= 1 and | ||||
|                 not is_start_piece(vocab_id_to_token_dict[token])): | ||||
|         if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]): | ||||
|             cand_indexes[-1].append(i) | ||||
|         else: | ||||
|             cand_indexes.append([i]) | ||||
| @@ -211,16 +217,14 @@ def create_masked_lm_predictions(tokens, | ||||
|     masked_lm_labels = [] | ||||
|  | ||||
|     if masked_lm_prob == 0: | ||||
|         return (output_tokens, masked_lm_positions, | ||||
|                 masked_lm_labels, token_boundary) | ||||
|         return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) | ||||
|  | ||||
|     num_to_predict = min(max_predictions_per_seq, | ||||
|                          max(1, int(round(len(tokens) * masked_lm_prob)))) | ||||
|     num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) | ||||
|  | ||||
|     # Note(mingdachen): | ||||
|     # By default, we set the probabilities to favor shorter ngram sequences. | ||||
|     ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) | ||||
|     pvals = 1. / np.arange(1, max_ngrams + 1) | ||||
|     pvals = 1.0 / np.arange(1, max_ngrams + 1) | ||||
|     pvals /= pvals.sum(keepdims=True) | ||||
|  | ||||
|     if favor_longer_ngram: | ||||
| @@ -230,7 +234,7 @@ def create_masked_lm_predictions(tokens, | ||||
|     for idx in range(len(cand_indexes)): | ||||
|         ngram_index = [] | ||||
|         for n in ngrams: | ||||
|             ngram_index.append(cand_indexes[idx:idx + n]) | ||||
|             ngram_index.append(cand_indexes[idx : idx + n]) | ||||
|         ngram_indexes.append(ngram_index) | ||||
|  | ||||
|     np_rng.shuffle(ngram_indexes) | ||||
| @@ -249,9 +253,10 @@ def create_masked_lm_predictions(tokens, | ||||
|                 if index in covered_indexes: | ||||
|                     continue | ||||
|  | ||||
|         n = np_rng.choice(ngrams[:len(cand_index_set)], | ||||
|                           p=pvals[:len(cand_index_set)] / | ||||
|                           pvals[:len(cand_index_set)].sum(keepdims=True)) | ||||
|         n = np_rng.choice( | ||||
|             ngrams[: len(cand_index_set)], | ||||
|             p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), | ||||
|         ) | ||||
|         index_set = sum(cand_index_set[n - 1], []) | ||||
|         n -= 1 | ||||
|         # Note(mingdachen): | ||||
| @@ -309,9 +314,10 @@ def create_masked_lm_predictions(tokens, | ||||
|                     if index in covered_indexes or index in select_indexes: | ||||
|                         continue | ||||
|  | ||||
|             n = np.random.choice(ngrams[:len(cand_index_set)], | ||||
|                                  p=pvals[:len(cand_index_set)] / | ||||
|                                  pvals[:len(cand_index_set)].sum(keepdims=True)) | ||||
|             n = np.random.choice( | ||||
|                 ngrams[: len(cand_index_set)], | ||||
|                 p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), | ||||
|             ) | ||||
|             index_set = sum(cand_index_set[n - 1], []) | ||||
|             n -= 1 | ||||
|  | ||||
| @@ -353,8 +359,7 @@ def create_masked_lm_predictions(tokens, | ||||
|     return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) | ||||
|  | ||||
|  | ||||
| def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, | ||||
|                              masked_labels, pad_id, max_seq_length): | ||||
| def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length): | ||||
|     """Pad sequences and convert them to numpy.""" | ||||
|  | ||||
|     # Some checks. | ||||
| @@ -370,8 +375,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, | ||||
|     tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) | ||||
|  | ||||
|     # Padding mask. | ||||
|     padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, | ||||
|                                dtype=np.int64) | ||||
|     padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64) | ||||
|  | ||||
|     # Lables and loss mask. | ||||
|     labels = [-1] * max_seq_length | ||||
| @@ -386,26 +390,36 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, | ||||
|     return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np | ||||
|  | ||||
|  | ||||
| def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                                     train_valid_test_num_samples, | ||||
|                                     max_seq_length, masked_lm_prob, | ||||
|                                     short_seq_prob, seed, skip_warmup, | ||||
|                                     binary_head, | ||||
|                                     dataset_type='standard_bert'): | ||||
|  | ||||
| def build_train_valid_test_datasets( | ||||
|     data_prefix, | ||||
|     data_impl, | ||||
|     splits_string, | ||||
|     train_valid_test_num_samples, | ||||
|     max_seq_length, | ||||
|     masked_lm_prob, | ||||
|     short_seq_prob, | ||||
|     seed, | ||||
|     skip_warmup, | ||||
|     binary_head, | ||||
|     dataset_type="standard_bert", | ||||
| ): | ||||
|     if len(data_prefix) == 1: | ||||
|         return _build_train_valid_test_datasets(data_prefix[0], | ||||
|                                                 data_impl, splits_string, | ||||
|                                                 train_valid_test_num_samples, | ||||
|                                                 max_seq_length, masked_lm_prob, | ||||
|                                                 short_seq_prob, seed, | ||||
|                                                 skip_warmup, | ||||
|                                                 binary_head, | ||||
|                                                 dataset_type=dataset_type) | ||||
|         return _build_train_valid_test_datasets( | ||||
|             data_prefix[0], | ||||
|             data_impl, | ||||
|             splits_string, | ||||
|             train_valid_test_num_samples, | ||||
|             max_seq_length, | ||||
|             masked_lm_prob, | ||||
|             short_seq_prob, | ||||
|             seed, | ||||
|             skip_warmup, | ||||
|             binary_head, | ||||
|             dataset_type=dataset_type, | ||||
|         ) | ||||
|     # Blending dataset. | ||||
|     # Parse the values. | ||||
|     output = get_datasets_weights_and_num_samples(data_prefix, | ||||
|                                                   train_valid_test_num_samples) | ||||
|     output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) | ||||
|     prefixes, weights, datasets_train_valid_test_num_samples = output | ||||
|  | ||||
|     # Build individual datasets. | ||||
| @@ -414,10 +428,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|     test_datasets = [] | ||||
|     for i in range(len(prefixes)): | ||||
|         train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( | ||||
|             prefixes[i], data_impl, splits_string, | ||||
|             prefixes[i], | ||||
|             data_impl, | ||||
|             splits_string, | ||||
|             datasets_train_valid_test_num_samples[i], | ||||
|             max_seq_length, masked_lm_prob, short_seq_prob, | ||||
|             seed, skip_warmup, binary_head, dataset_type=dataset_type) | ||||
|             max_seq_length, | ||||
|             masked_lm_prob, | ||||
|             short_seq_prob, | ||||
|             seed, | ||||
|             skip_warmup, | ||||
|             binary_head, | ||||
|             dataset_type=dataset_type, | ||||
|         ) | ||||
|         if train_ds: | ||||
|             train_datasets.append(train_ds) | ||||
|         if valid_ds: | ||||
| @@ -436,31 +458,33 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|     if test_datasets: | ||||
|         blending_test_dataset = BlendableDataset(test_datasets, weights) | ||||
|  | ||||
|     return (blending_train_dataset, blending_valid_dataset, | ||||
|             blending_test_dataset) | ||||
|     return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) | ||||
|  | ||||
|  | ||||
| def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                                      train_valid_test_num_samples, | ||||
|                                      max_seq_length, masked_lm_prob, | ||||
|                                      short_seq_prob, seed, skip_warmup, | ||||
|                                      binary_head, | ||||
|                                      dataset_type='standard_bert'): | ||||
| def _build_train_valid_test_datasets( | ||||
|     data_prefix, | ||||
|     data_impl, | ||||
|     splits_string, | ||||
|     train_valid_test_num_samples, | ||||
|     max_seq_length, | ||||
|     masked_lm_prob, | ||||
|     short_seq_prob, | ||||
|     seed, | ||||
|     skip_warmup, | ||||
|     binary_head, | ||||
|     dataset_type="standard_bert", | ||||
| ): | ||||
|     logger = get_dist_logger() | ||||
|  | ||||
|     if dataset_type not in DSET_TYPES: | ||||
|         raise ValueError("Invalid dataset_type: ", dataset_type) | ||||
|  | ||||
|     # Indexed dataset. | ||||
|     indexed_dataset = get_indexed_dataset_(data_prefix, | ||||
|                                            data_impl, | ||||
|                                            skip_warmup) | ||||
|     indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) | ||||
|  | ||||
|     if dataset_type == DSET_TYPE_ICT: | ||||
|         args = get_args() | ||||
|         title_dataset = get_indexed_dataset_(args.titles_data_path, | ||||
|                                              data_impl, | ||||
|                                              skip_warmup) | ||||
|         title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup) | ||||
|  | ||||
|     # Get start and end indices of train/valid/train into doc-idx | ||||
|     # Note that doc-idx is designed to be num-docs + 1 so we can | ||||
| @@ -469,27 +493,29 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|     splits = get_train_valid_test_split_(splits_string, total_num_of_documents) | ||||
|  | ||||
|     # Print stats about the splits. | ||||
|     logger.info('\n > dataset split:') | ||||
|     logger.info("\n > dataset split:") | ||||
|  | ||||
|     def print_split_stats(name, index): | ||||
|         start_index = indexed_dataset.doc_idx[splits[index]] | ||||
|         end_index = indexed_dataset.doc_idx[splits[index + 1]] | ||||
|         logger.info('\n    {}:'.format(name) + | ||||
|                     '\n     document indices in [{}, {}) total of {} documents'.format( | ||||
|                         splits[index], | ||||
|                         splits[index + 1], | ||||
|                         splits[index + 1] - splits[index]) + | ||||
|                     '\n     sentence indices in [{}, {}) total of {} sentences'.format( | ||||
|                         start_index, | ||||
|                         end_index, | ||||
|                         end_index - start_index), | ||||
|                     ranks=[0]) | ||||
|     print_split_stats('train', 0) | ||||
|     print_split_stats('validation', 1) | ||||
|     print_split_stats('test', 2) | ||||
|         logger.info( | ||||
|             "\n    {}:".format(name) | ||||
|             + "\n     document indices in [{}, {}) total of {} documents".format( | ||||
|                 splits[index], splits[index + 1], splits[index + 1] - splits[index] | ||||
|             ) | ||||
|             + "\n     sentence indices in [{}, {}) total of {} sentences".format( | ||||
|                 start_index, end_index, end_index - start_index | ||||
|             ), | ||||
|             ranks=[0], | ||||
|         ) | ||||
|  | ||||
|     print_split_stats("train", 0) | ||||
|     print_split_stats("validation", 1) | ||||
|     print_split_stats("test", 2) | ||||
|  | ||||
|     def build_dataset(index, name): | ||||
|         from .bert_dataset import BertDataset | ||||
|  | ||||
|         dataset = None | ||||
|         if splits[index + 1] > splits[index]: | ||||
|             # Get the pointer to the original doc-idx so we can set it later. | ||||
| @@ -508,7 +534,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                 max_num_samples=train_valid_test_num_samples[index], | ||||
|                 max_seq_length=max_seq_length, | ||||
|                 seed=seed, | ||||
|                 binary_head=binary_head | ||||
|                 binary_head=binary_head, | ||||
|             ) | ||||
|  | ||||
|             if dataset_type == DSET_TYPE_ICT: | ||||
| @@ -518,27 +544,26 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
|                     title_dataset=title_dataset, | ||||
|                     query_in_block_prob=args.query_in_block_prob, | ||||
|                     use_one_sent_docs=args.use_one_sent_docs, | ||||
|                     **kwargs | ||||
|                     **kwargs, | ||||
|                 ) | ||||
|             else: | ||||
|                 dataset = BertDataset( | ||||
|                     indexed_dataset=indexed_dataset, | ||||
|                     masked_lm_prob=masked_lm_prob, | ||||
|                     short_seq_prob=short_seq_prob, | ||||
|                     **kwargs | ||||
|                     **kwargs, | ||||
|                 ) | ||||
|  | ||||
|             # Set the original pointer so dataset remains the main dataset. | ||||
|             indexed_dataset.set_doc_idx(doc_idx_ptr) | ||||
|             # Checks. | ||||
|             assert indexed_dataset.doc_idx[0] == 0 | ||||
|             assert indexed_dataset.doc_idx.shape[0] == \ | ||||
|                 (total_num_of_documents + 1) | ||||
|             assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) | ||||
|         return dataset | ||||
|  | ||||
|     train_dataset = build_dataset(0, 'train') | ||||
|     valid_dataset = build_dataset(1, 'valid') | ||||
|     test_dataset = build_dataset(2, 'test') | ||||
|     train_dataset = build_dataset(0, "train") | ||||
|     valid_dataset = build_dataset(1, "valid") | ||||
|     test_dataset = build_dataset(2, "test") | ||||
|  | ||||
|     return (train_dataset, valid_dataset, test_dataset) | ||||
|  | ||||
| @@ -546,44 +571,41 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, | ||||
| def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): | ||||
|     logger = get_dist_logger() | ||||
|     start_time = time.time() | ||||
|     indexed_dataset = make_indexed_dataset(data_prefix, | ||||
|                                            data_impl, | ||||
|                                            skip_warmup) | ||||
|     indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) | ||||
|     assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] | ||||
|     logger.info('\n > building dataset index ...', ranks=[0]) | ||||
|     logger.info('\n > finished creating indexed dataset in {:4f} ' | ||||
|                 'seconds'.format(time.time() - start_time), ranks=[0]) | ||||
|     logger.info('\n > indexed dataset stats:' + | ||||
|                 '\n    number of documents: {}'.format( | ||||
|                     indexed_dataset.doc_idx.shape[0] - 1) + | ||||
|                 '\n    number of sentences: {}'.format( | ||||
|                     indexed_dataset.sizes.shape[0]), | ||||
|                 ranks=[0] | ||||
|                 ) | ||||
|     logger.info("\n > building dataset index ...", ranks=[0]) | ||||
|     logger.info( | ||||
|         "\n > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time), ranks=[0] | ||||
|     ) | ||||
|     logger.info( | ||||
|         "\n > indexed dataset stats:" | ||||
|         + "\n    number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1) | ||||
|         + "\n    number of sentences: {}".format(indexed_dataset.sizes.shape[0]), | ||||
|         ranks=[0], | ||||
|     ) | ||||
|  | ||||
|     return indexed_dataset | ||||
|  | ||||
|  | ||||
| def get_train_valid_test_split_(splits_string, size): | ||||
|     """ Get dataset splits from comma or '/' separated string list.""" | ||||
|     """Get dataset splits from comma or '/' separated string list.""" | ||||
|  | ||||
|     splits = [] | ||||
|     if splits_string.find(',') != -1: | ||||
|         splits = [float(s) for s in splits_string.split(',')] | ||||
|     elif splits_string.find('/') != -1: | ||||
|         splits = [float(s) for s in splits_string.split('/')] | ||||
|     if splits_string.find(",") != -1: | ||||
|         splits = [float(s) for s in splits_string.split(",")] | ||||
|     elif splits_string.find("/") != -1: | ||||
|         splits = [float(s) for s in splits_string.split("/")] | ||||
|     else: | ||||
|         splits = [float(splits_string)] | ||||
|     while len(splits) < 3: | ||||
|         splits.append(0.) | ||||
|         splits.append(0.0) | ||||
|     splits = splits[:3] | ||||
|     splits_sum = sum(splits) | ||||
|     assert splits_sum > 0.0 | ||||
|     splits = [split / splits_sum for split in splits] | ||||
|     splits_index = [0] | ||||
|     for index, split in enumerate(splits): | ||||
|         splits_index.append(splits_index[index] + | ||||
|                             int(round(split * float(size)))) | ||||
|         splits_index.append(splits_index[index] + int(round(split * float(size)))) | ||||
|     diff = splits_index[-1] - size | ||||
|     for index in range(1, len(splits_index)): | ||||
|         splits_index[index] -= diff | ||||
|   | ||||
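Aside: the get_train_valid_test_split_ hunk above only changes quoting and line width, but the split arithmetic it wraps is easy to sanity-check in isolation. The sketch below re-creates that arithmetic under the assumption of a splits string such as "949,50,1" over 1000 documents; the function name and the numbers are illustrative, not taken from the repository.

def split_boundaries(splits_string, size):
    # Accept "a,b,c" or "a/b/c"; pad to three entries, truncate extras, normalize.
    sep = "," if "," in splits_string else "/"
    splits = [float(s) for s in splits_string.split(sep)]
    splits = (splits + [0.0, 0.0])[:3]
    total = sum(splits)
    splits = [s / total for s in splits]
    # Turn fractions into cumulative document-index boundaries.
    index = [0]
    for s in splits:
        index.append(index[-1] + int(round(s * float(size))))
    # Shift so the final boundary lands exactly on `size` despite rounding.
    diff = index[-1] - size
    for i in range(1, len(index)):
        index[i] -= diff
    return index

print(split_boundaries("949,50,1", 1000))  # [0, 949, 999, 1000]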
										
											
												File diff suppressed because it is too large
							| @@ -2,12 +2,11 @@ import itertools | ||||
| import random | ||||
|  | ||||
| import numpy as np | ||||
| from torch.utils.data import Dataset | ||||
|  | ||||
| from megatron import get_tokenizer | ||||
| from megatron import get_args | ||||
| from megatron import get_args, get_tokenizer | ||||
| from megatron.data.dataset_utils import get_indexed_dataset_ | ||||
| from megatron.data.realm_dataset_utils import get_block_samples_mapping | ||||
| from torch.utils.data import Dataset | ||||
|  | ||||
|  | ||||
| def make_attention_mask(source_block, target_block): | ||||
|     """ | ||||
| @@ -20,16 +19,17 @@ def make_attention_mask(source_block, target_block): | ||||
|     # (source_length, target_length) | ||||
|     return mask | ||||
|  | ||||
|  | ||||
| def get_ict_dataset(use_titles=True, query_in_block_prob=1): | ||||
|     """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) | ||||
|     rather than for training, since it is only built with a single epoch sample mapping. | ||||
|     """ | ||||
|     args = get_args() | ||||
|     block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) | ||||
|     titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) | ||||
|     block_dataset = get_indexed_dataset_(args.data_path, "mmap", True) | ||||
|     titles_dataset = get_indexed_dataset_(args.titles_data_path, "mmap", True) | ||||
|  | ||||
|     kwargs = dict( | ||||
|         name='full', | ||||
|         name="full", | ||||
|         block_dataset=block_dataset, | ||||
|         title_dataset=titles_dataset, | ||||
|         data_prefix=args.data_path, | ||||
| @@ -39,7 +39,7 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1): | ||||
|         seed=1, | ||||
|         query_in_block_prob=query_in_block_prob, | ||||
|         use_titles=use_titles, | ||||
|         use_one_sent_docs=args.use_one_sent_docs | ||||
|         use_one_sent_docs=args.use_one_sent_docs, | ||||
|     ) | ||||
|     dataset = ICTDataset(**kwargs) | ||||
|     return dataset | ||||
| @@ -47,9 +47,22 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1): | ||||
|  | ||||
| class ICTDataset(Dataset): | ||||
|     """Dataset containing sentences and their blocks for an inverse cloze task.""" | ||||
|     def __init__(self, name, block_dataset, title_dataset, data_prefix, | ||||
|                  num_epochs, max_num_samples, max_seq_length, query_in_block_prob, | ||||
|                  seed, use_titles=True, use_one_sent_docs=False, binary_head=False): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         name, | ||||
|         block_dataset, | ||||
|         title_dataset, | ||||
|         data_prefix, | ||||
|         num_epochs, | ||||
|         max_num_samples, | ||||
|         max_seq_length, | ||||
|         query_in_block_prob, | ||||
|         seed, | ||||
|         use_titles=True, | ||||
|         use_one_sent_docs=False, | ||||
|         binary_head=False, | ||||
|     ): | ||||
|         self.name = name | ||||
|         self.seed = seed | ||||
|         self.max_seq_length = max_seq_length | ||||
| @@ -61,8 +74,16 @@ class ICTDataset(Dataset): | ||||
|         self.use_one_sent_docs = use_one_sent_docs | ||||
|  | ||||
|         self.samples_mapping = get_block_samples_mapping( | ||||
|             block_dataset, title_dataset, data_prefix, num_epochs, | ||||
|             max_num_samples, max_seq_length, seed, name, use_one_sent_docs) | ||||
|             block_dataset, | ||||
|             title_dataset, | ||||
|             data_prefix, | ||||
|             num_epochs, | ||||
|             max_num_samples, | ||||
|             max_seq_length, | ||||
|             seed, | ||||
|             name, | ||||
|             use_one_sent_docs, | ||||
|         ) | ||||
|         self.tokenizer = get_tokenizer() | ||||
|         self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) | ||||
|         self.vocab_id_to_token_list = self.tokenizer.inv_vocab | ||||
| @@ -99,8 +120,8 @@ class ICTDataset(Dataset): | ||||
|  | ||||
|         # still need to truncate because blocks are concluded when | ||||
|         # the sentence lengths have exceeded max_seq_length. | ||||
|         query = query[:self.max_seq_length - 2] | ||||
|         block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] | ||||
|         query = query[: self.max_seq_length - 2] | ||||
|         block = list(itertools.chain(*block))[: self.max_seq_length - title_pad_offset] | ||||
|  | ||||
|         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) | ||||
|         context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) | ||||
| @@ -111,13 +132,13 @@ class ICTDataset(Dataset): | ||||
|         block_data = sample_data.as_array() | ||||
|  | ||||
|         sample = { | ||||
|             'query_tokens': query_tokens, | ||||
|             'query_mask': query_mask, | ||||
|             'query_pad_mask': query_pad_mask, | ||||
|             'context_tokens': context_tokens, | ||||
|             'context_mask': context_mask, | ||||
|             'context_pad_mask': context_pad_mask, | ||||
|             'block_data': block_data, | ||||
|             "query_tokens": query_tokens, | ||||
|             "query_mask": query_mask, | ||||
|             "query_pad_mask": query_pad_mask, | ||||
|             "context_tokens": context_tokens, | ||||
|             "context_mask": context_mask, | ||||
|             "context_pad_mask": context_pad_mask, | ||||
|             "block_data": block_data, | ||||
|         } | ||||
|  | ||||
|         return sample | ||||
| @@ -127,7 +148,7 @@ class ICTDataset(Dataset): | ||||
|         block = [self.block_dataset[i] for i in range(start_idx, end_idx)] | ||||
|         title = self.title_dataset[int(doc_idx)] | ||||
|  | ||||
|         block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] | ||||
|         block = list(itertools.chain(*block))[: self.max_seq_length - (3 + len(title))] | ||||
|         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) | ||||
|  | ||||
|         return block_tokens, block_pad_mask | ||||
|   | ||||
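For context on the ICT dataset file touched above: make_attention_mask (at the top of this file) turns two 1-D token arrays into a 2-D attention mask of shape (source_length, target_length) via an outer product over non-padding positions. The NumPy sketch below illustrates that pattern with made-up token ids and the usual assumption that padding uses id 0; it is an illustration, not a copy of the repository's implementation.

import numpy as np

def outer_attention_mask(source_ids, target_ids):
    # 1 where both the source and the target positions hold real (non-pad) tokens.
    mask = (source_ids[:, None] > 0) & (target_ids[None, :] > 0)
    return mask.astype(np.int64)  # shape: (source_length, target_length)

query = np.array([101, 7592, 2088, 102, 0, 0])  # hypothetical padded query
block = np.array([101, 2023, 2003, 1037, 102])  # hypothetical padded block
print(outer_attention_mask(query, block).shape)  # (6, 5)
print(outer_attention_mask(query, block)[4])     # all zeros: row for a pad position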
| @@ -27,17 +27,17 @@ def __best_fitting_dtype(vocab_size=None): | ||||
|  | ||||
|  | ||||
| def get_available_dataset_impl(): | ||||
|     return ['lazy', 'cached', 'mmap'] | ||||
|     return ["lazy", "cached", "mmap"] | ||||
|  | ||||
|  | ||||
| def infer_dataset_impl(path): | ||||
|     if IndexedDataset.exists(path): | ||||
|         with open(index_file_path(path), 'rb') as f: | ||||
|         with open(index_file_path(path), "rb") as f: | ||||
|             magic = f.read(8) | ||||
|             if magic == IndexedDataset._HDR_MAGIC: | ||||
|                 return 'cached' | ||||
|                 return "cached" | ||||
|             elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: | ||||
|                 return 'mmap' | ||||
|                 return "mmap" | ||||
|             else: | ||||
|                 return None | ||||
|     else: | ||||
| @@ -47,7 +47,7 @@ def infer_dataset_impl(path): | ||||
|  | ||||
|  | ||||
| def make_builder(out_file, impl, vocab_size=None): | ||||
|     if impl == 'mmap': | ||||
|     if impl == "mmap": | ||||
|         return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) | ||||
|     else: | ||||
|         return IndexedDatasetBuilder(out_file) | ||||
| @@ -58,20 +58,20 @@ def make_dataset(path, impl, skip_warmup=False): | ||||
|         print(f"Dataset does not exist: {path}") | ||||
|         print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") | ||||
|         return None | ||||
|     if impl == 'infer': | ||||
|     if impl == "infer": | ||||
|         impl = infer_dataset_impl(path) | ||||
|     if impl == 'lazy' and IndexedDataset.exists(path): | ||||
|     if impl == "lazy" and IndexedDataset.exists(path): | ||||
|         return IndexedDataset(path) | ||||
|     elif impl == 'cached' and IndexedDataset.exists(path): | ||||
|     elif impl == "cached" and IndexedDataset.exists(path): | ||||
|         return IndexedCachedDataset(path) | ||||
|     elif impl == 'mmap' and MMapIndexedDataset.exists(path): | ||||
|     elif impl == "mmap" and MMapIndexedDataset.exists(path): | ||||
|         return MMapIndexedDataset(path, skip_warmup) | ||||
|     print(f"Unknown dataset implementation: {impl}") | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def dataset_exists(path, impl): | ||||
|     if impl == 'mmap': | ||||
|     if impl == "mmap": | ||||
|         return MMapIndexedDataset.exists(path) | ||||
|     else: | ||||
|         return IndexedDataset.exists(path) | ||||
| @@ -98,11 +98,11 @@ def code(dtype): | ||||
|  | ||||
|  | ||||
| def index_file_path(prefix_path): | ||||
|     return prefix_path + '.idx' | ||||
|     return prefix_path + ".idx" | ||||
|  | ||||
|  | ||||
| def data_file_path(prefix_path): | ||||
|     return prefix_path + '.bin' | ||||
|     return prefix_path + ".bin" | ||||
|  | ||||
|  | ||||
| def create_doc_idx(sizes): | ||||
| @@ -115,7 +115,8 @@ def create_doc_idx(sizes): | ||||
|  | ||||
| class IndexedDataset(torch.utils.data.Dataset): | ||||
|     """Loader for IndexedDataset""" | ||||
|     _HDR_MAGIC = b'TNTIDX\x00\x00' | ||||
|  | ||||
|     _HDR_MAGIC = b"TNTIDX\x00\x00" | ||||
|  | ||||
|     def __init__(self, path): | ||||
|         super().__init__() | ||||
| @@ -124,27 +125,28 @@ class IndexedDataset(torch.utils.data.Dataset): | ||||
|         self.read_index(path) | ||||
|  | ||||
|     def read_index(self, path): | ||||
|         with open(index_file_path(path), 'rb') as f: | ||||
|         with open(index_file_path(path), "rb") as f: | ||||
|             magic = f.read(8) | ||||
|             assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. ' | ||||
|                                               'Make sure that --dataset-impl is configured properly.') | ||||
|             assert magic == self._HDR_MAGIC, ( | ||||
|                 "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." | ||||
|             ) | ||||
|             version = f.read(8) | ||||
|             assert struct.unpack('<Q', version) == (1,) | ||||
|             code, self.element_size = struct.unpack('<QQ', f.read(16)) | ||||
|             assert struct.unpack("<Q", version) == (1,) | ||||
|             code, self.element_size = struct.unpack("<QQ", f.read(16)) | ||||
|             self.dtype = dtypes[code] | ||||
|             self._len, self.s = struct.unpack('<QQ', f.read(16)) | ||||
|             self.doc_count = struct.unpack('<Q', f.read(8)) | ||||
|             self._len, self.s = struct.unpack("<QQ", f.read(16)) | ||||
|             self.doc_count = struct.unpack("<Q", f.read(8)) | ||||
|             self.dim_offsets = read_longs(f, self._len + 1) | ||||
|             self.data_offsets = read_longs(f, self._len + 1) | ||||
|             self.sizes = read_longs(f, self.s) | ||||
|             self.doc_idx = read_longs(f, self.doc_count) | ||||
|  | ||||
|     def read_data(self, path): | ||||
|         self.data_file = open(data_file_path(path), 'rb', buffering=0) | ||||
|         self.data_file = open(data_file_path(path), "rb", buffering=0) | ||||
|  | ||||
|     def check_index(self, i): | ||||
|         if i < 0 or i >= self._len: | ||||
|             raise IndexError('index out of range') | ||||
|             raise IndexError("index out of range") | ||||
|  | ||||
|     def __del__(self): | ||||
|         if self.data_file: | ||||
| @@ -157,7 +159,7 @@ class IndexedDataset(torch.utils.data.Dataset): | ||||
|         if isinstance(idx, int): | ||||
|             i = idx | ||||
|             self.check_index(i) | ||||
|             tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] | ||||
|             tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] | ||||
|             a = np.empty(tensor_size, dtype=self.dtype) | ||||
|             self.data_file.seek(self.data_offsets[i] * self.element_size) | ||||
|             self.data_file.readinto(a) | ||||
| @@ -166,7 +168,7 @@ class IndexedDataset(torch.utils.data.Dataset): | ||||
|             start, stop, step = idx.indices(len(self)) | ||||
|             if step != 1: | ||||
|                 raise ValueError("Slices into indexed_dataset must be contiguous") | ||||
|             sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] | ||||
|             sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] | ||||
|             size = sum(sizes) | ||||
|             a = np.empty(size, dtype=self.dtype) | ||||
|             self.data_file.seek(self.data_offsets[start] * self.element_size) | ||||
| @@ -186,15 +188,14 @@ class IndexedDataset(torch.utils.data.Dataset): | ||||
|  | ||||
|     @staticmethod | ||||
|     def exists(path): | ||||
|         return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))) | ||||
|         return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) | ||||
|  | ||||
|     @property | ||||
|     def supports_prefetch(self): | ||||
|         return False    # avoid prefetching to save memory | ||||
|         return False  # avoid prefetching to save memory | ||||
|  | ||||
|  | ||||
| class IndexedCachedDataset(IndexedDataset): | ||||
|  | ||||
|     def __init__(self, path): | ||||
|         super().__init__(path) | ||||
|         self.cache = None | ||||
| @@ -219,7 +220,7 @@ class IndexedCachedDataset(IndexedDataset): | ||||
|         for i in indices: | ||||
|             self.cache_index[i] = ptx | ||||
|             size = self.data_offsets[i + 1] - self.data_offsets[i] | ||||
|             a = self.cache[ptx:ptx + size] | ||||
|             a = self.cache[ptx : ptx + size] | ||||
|             self.data_file.seek(self.data_offsets[i] * self.element_size) | ||||
|             self.data_file.readinto(a) | ||||
|             ptx += size | ||||
| @@ -233,10 +234,10 @@ class IndexedCachedDataset(IndexedDataset): | ||||
|         if isinstance(idx, int): | ||||
|             i = idx | ||||
|             self.check_index(i) | ||||
|             tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] | ||||
|             tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] | ||||
|             a = np.empty(tensor_size, dtype=self.dtype) | ||||
|             ptx = self.cache_index[i] | ||||
|             np.copyto(a, self.cache[ptx:ptx + a.size]) | ||||
|             np.copyto(a, self.cache[ptx : ptx + a.size]) | ||||
|             return a | ||||
|         elif isinstance(idx, slice): | ||||
|             # Hack just to make this work, can optimizer later if necessary | ||||
| @@ -250,7 +251,7 @@ class IndexedDatasetBuilder(object): | ||||
|     element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8} | ||||
|  | ||||
|     def __init__(self, out_file, dtype=np.int32): | ||||
|         self.out_file = open(out_file, 'wb') | ||||
|         self.out_file = open(out_file, "wb") | ||||
|         self.dtype = dtype | ||||
|         self.data_offsets = [0] | ||||
|         self.dim_offsets = [0] | ||||
| @@ -280,7 +281,7 @@ class IndexedDatasetBuilder(object): | ||||
|         for dim_offset in index.dim_offsets[1:]: | ||||
|             self.dim_offsets.append(begin + dim_offset) | ||||
|  | ||||
|         with open(data_file_path(another_file), 'rb') as f: | ||||
|         with open(data_file_path(another_file), "rb") as f: | ||||
|             while True: | ||||
|                 data = f.read(1024) | ||||
|                 if data: | ||||
| @@ -290,12 +291,12 @@ class IndexedDatasetBuilder(object): | ||||
|  | ||||
|     def finalize(self, index_file): | ||||
|         self.out_file.close() | ||||
|         index = open(index_file, 'wb') | ||||
|         index.write(b'TNTIDX\x00\x00') | ||||
|         index.write(struct.pack('<Q', 1)) | ||||
|         index.write(struct.pack('<QQ', code(self.dtype), self.element_size)) | ||||
|         index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes))) | ||||
|         index.write(struct.pack('<Q', len(self.doc_idx))) | ||||
|         index = open(index_file, "wb") | ||||
|         index.write(b"TNTIDX\x00\x00") | ||||
|         index.write(struct.pack("<Q", 1)) | ||||
|         index.write(struct.pack("<QQ", code(self.dtype), self.element_size)) | ||||
|         index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes))) | ||||
|         index.write(struct.pack("<Q", len(self.doc_idx))) | ||||
|         write_longs(index, self.dim_offsets) | ||||
|         write_longs(index, self.data_offsets) | ||||
|         write_longs(index, self.sizes) | ||||
| @@ -304,27 +305,24 @@ class IndexedDatasetBuilder(object): | ||||
|  | ||||
|  | ||||
| def _warmup_mmap_file(path): | ||||
|     with open(path, 'rb') as stream: | ||||
|     with open(path, "rb") as stream: | ||||
|         while stream.read(100 * 1024 * 1024): | ||||
|             pass | ||||
|  | ||||
|  | ||||
| class MMapIndexedDataset(torch.utils.data.Dataset): | ||||
|  | ||||
|     class Index(object): | ||||
|         _HDR_MAGIC = b'MMIDIDX\x00\x00' | ||||
|         _HDR_MAGIC = b"MMIDIDX\x00\x00" | ||||
|  | ||||
|         @classmethod | ||||
|         def writer(cls, path, dtype): | ||||
|  | ||||
|             class _Writer(object): | ||||
|  | ||||
|                 def __enter__(self): | ||||
|                     self._file = open(path, 'wb') | ||||
|                     self._file = open(path, "wb") | ||||
|  | ||||
|                     self._file.write(cls._HDR_MAGIC) | ||||
|                     self._file.write(struct.pack('<Q', 1)) | ||||
|                     self._file.write(struct.pack('<B', code(dtype))) | ||||
|                     self._file.write(struct.pack("<Q", 1)) | ||||
|                     self._file.write(struct.pack("<B", code(dtype))) | ||||
|  | ||||
|                     return self | ||||
|  | ||||
| @@ -343,19 +341,19 @@ class MMapIndexedDataset(torch.utils.data.Dataset): | ||||
|                 def write(self, sizes, doc_idx): | ||||
|                     pointers = self._get_pointers(sizes) | ||||
|  | ||||
|                     self._file.write(struct.pack('<Q', len(sizes))) | ||||
|                     self._file.write(struct.pack('<Q', len(doc_idx))) | ||||
|                     self._file.write(struct.pack("<Q", len(sizes))) | ||||
|                     self._file.write(struct.pack("<Q", len(doc_idx))) | ||||
|  | ||||
|                     sizes = np.array(sizes, dtype=np.int32) | ||||
|                     self._file.write(sizes.tobytes(order='C')) | ||||
|                     self._file.write(sizes.tobytes(order="C")) | ||||
|                     del sizes | ||||
|  | ||||
|                     pointers = np.array(pointers, dtype=np.int64) | ||||
|                     self._file.write(pointers.tobytes(order='C')) | ||||
|                     self._file.write(pointers.tobytes(order="C")) | ||||
|                     del pointers | ||||
|  | ||||
|                     doc_idx = np.array(doc_idx, dtype=np.int64) | ||||
|                     self._file.write(doc_idx.tobytes(order='C')) | ||||
|                     self._file.write(doc_idx.tobytes(order="C")) | ||||
|  | ||||
|                 def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|                     self._file.close() | ||||
| @@ -363,39 +361,41 @@ class MMapIndexedDataset(torch.utils.data.Dataset): | ||||
|             return _Writer() | ||||
|  | ||||
|         def __init__(self, path, skip_warmup=False): | ||||
|             with open(path, 'rb') as stream: | ||||
|             with open(path, "rb") as stream: | ||||
|                 magic_test = stream.read(9) | ||||
|                 assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. ' | ||||
|                                                        'Make sure that --dataset-impl is configured properly.') | ||||
|                 version = struct.unpack('<Q', stream.read(8)) | ||||
|                 assert self._HDR_MAGIC == magic_test, ( | ||||
|                     "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." | ||||
|                 ) | ||||
|                 version = struct.unpack("<Q", stream.read(8)) | ||||
|                 assert (1,) == version | ||||
|  | ||||
|                 dtype_code, = struct.unpack('<B', stream.read(1)) | ||||
|                 (dtype_code,) = struct.unpack("<B", stream.read(1)) | ||||
|                 self._dtype = dtypes[dtype_code] | ||||
|                 self._dtype_size = self._dtype().itemsize | ||||
|  | ||||
|                 self._len = struct.unpack('<Q', stream.read(8))[0] | ||||
|                 self._doc_count = struct.unpack('<Q', stream.read(8))[0] | ||||
|                 self._len = struct.unpack("<Q", stream.read(8))[0] | ||||
|                 self._doc_count = struct.unpack("<Q", stream.read(8))[0] | ||||
|                 offset = stream.tell() | ||||
|  | ||||
|             if not skip_warmup: | ||||
|                 print("    warming up index mmap file...") | ||||
|                 _warmup_mmap_file(path) | ||||
|  | ||||
|             self._bin_buffer_mmap = np.memmap(path, mode='r', order='C') | ||||
|             self._bin_buffer_mmap = np.memmap(path, mode="r", order="C") | ||||
|             self._bin_buffer = memoryview(self._bin_buffer_mmap) | ||||
|             print("    reading sizes...") | ||||
|             self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset) | ||||
|             print("    reading pointers...") | ||||
|             self._pointers = np.frombuffer(self._bin_buffer, | ||||
|                                            dtype=np.int64, | ||||
|                                            count=self._len, | ||||
|                                            offset=offset + self._sizes.nbytes) | ||||
|             self._pointers = np.frombuffer( | ||||
|                 self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes | ||||
|             ) | ||||
|             print("    reading document index...") | ||||
|             self._doc_idx = np.frombuffer(self._bin_buffer, | ||||
|                                           dtype=np.int64, | ||||
|                                           count=self._doc_count, | ||||
|                                           offset=offset + self._sizes.nbytes + self._pointers.nbytes) | ||||
|             self._doc_idx = np.frombuffer( | ||||
|                 self._bin_buffer, | ||||
|                 dtype=np.int64, | ||||
|                 count=self._doc_count, | ||||
|                 offset=offset + self._sizes.nbytes + self._pointers.nbytes, | ||||
|             ) | ||||
|  | ||||
|         def __del__(self): | ||||
|             self._bin_buffer_mmap._mmap.close() | ||||
| @@ -443,7 +443,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): | ||||
|             print("    warming up data mmap file...") | ||||
|             _warmup_mmap_file(data_file_path(self._path)) | ||||
|         print("    creating numpy buffer of mmap...") | ||||
|         self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C') | ||||
|         self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode="r", order="C") | ||||
|         print("    creating memory view of numpy buffer...") | ||||
|         self._bin_buffer = memoryview(self._bin_buffer_mmap) | ||||
|  | ||||
| @@ -474,7 +474,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset): | ||||
|             return sents | ||||
|  | ||||
|     def get(self, idx, offset=0, length=None): | ||||
|         """ Retrieves a single item from the dataset with the option to only | ||||
|         """Retrieves a single item from the dataset with the option to only | ||||
|         return a portion of the item. | ||||
|  | ||||
|         get(idx) is the same as [idx] but get() does not support slicing. | ||||
| @@ -506,20 +506,19 @@ class MMapIndexedDataset(torch.utils.data.Dataset): | ||||
|  | ||||
|     @staticmethod | ||||
|     def exists(path): | ||||
|         return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))) | ||||
|         return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) | ||||
|  | ||||
|  | ||||
| class MMapIndexedDatasetBuilder(object): | ||||
|  | ||||
|     def __init__(self, out_file, dtype=np.int64): | ||||
|         self._data_file = open(out_file, 'wb') | ||||
|         self._data_file = open(out_file, "wb") | ||||
|         self._dtype = dtype | ||||
|         self._sizes = [] | ||||
|         self._doc_idx = [0] | ||||
|  | ||||
|     def add_item(self, tensor): | ||||
|         np_array = np.array(tensor.numpy(), dtype=self._dtype) | ||||
|         self._data_file.write(np_array.tobytes(order='C')) | ||||
|         self._data_file.write(np_array.tobytes(order="C")) | ||||
|         self._sizes.append(np_array.size) | ||||
|  | ||||
|     def end_document(self): | ||||
| @@ -534,7 +533,7 @@ class MMapIndexedDatasetBuilder(object): | ||||
|             self._sizes.append(size) | ||||
|  | ||||
|         # Concatenate data | ||||
|         with open(data_file_path(another_file), 'rb') as f: | ||||
|         with open(data_file_path(another_file), "rb") as f: | ||||
|             shutil.copyfileobj(f, self._data_file) | ||||
|  | ||||
|     def finalize(self, index_file): | ||||
|   | ||||
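The indexed_dataset.py changes above are quote and slice formatting only, but they touch every struct.pack/struct.unpack call, so the fixed-size .idx header layout is worth spelling out once. The sketch below reads back exactly the fields the writer in the hunks above emits (magic, version, dtype code, element size, entry count, sizes count, doc count); the file name is a placeholder.

import struct

def read_indexed_dataset_header(path):
    # Fixed-size header of a lazy/cached IndexedDataset ".idx" file, mirroring
    # IndexedDatasetBuilder.finalize() and IndexedDataset.read_index() above.
    with open(path, "rb") as f:
        magic = f.read(8)
        assert magic == b"TNTIDX\x00\x00", "not an IndexedDataset index file"
        (version,) = struct.unpack("<Q", f.read(8))
        dtype_code, element_size = struct.unpack("<QQ", f.read(16))
        num_entries, num_sizes = struct.unpack("<QQ", f.read(16))
        (doc_count,) = struct.unpack("<Q", f.read(8))
    return dict(version=version, dtype_code=dtype_code, element_size=element_size,
                num_entries=num_entries, num_sizes=num_sizes, doc_count=doc_count)

# print(read_indexed_dataset_header("my-corpus.idx"))  # path is illustrative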
| @@ -2,13 +2,12 @@ | ||||
| # put some code used during development and manual testing of | ||||
| # indexed_dataset. | ||||
|  | ||||
| from megatron.data import indexed_dataset | ||||
| from megatron.tokenizer import build_tokenizer | ||||
| import argparse | ||||
| import os | ||||
| import sys | ||||
|  | ||||
| import torch | ||||
| from megatron.data import indexed_dataset | ||||
| from megatron.tokenizer import build_tokenizer | ||||
|  | ||||
| script_dir = os.path.dirname(os.path.realpath(__file__)) | ||||
| sys.path.append(os.path.join(script_dir, "../../../")) | ||||
| @@ -42,7 +41,7 @@ def test_indexed_dataset(args): | ||||
|  | ||||
| def test_indexed_dataset_get(args): | ||||
|     ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) | ||||
|     tokenizer = build_tokenizer(args) | ||||
|     build_tokenizer(args) | ||||
|     size = ds.sizes[0] | ||||
|     print(f"size: {size}") | ||||
|     full = ds.get(0) | ||||
| @@ -61,6 +60,7 @@ def test_indexed_dataset_get(args): | ||||
|     print(part) | ||||
|     # print(tokenizer.detokenize(part.data.tolist())) | ||||
|  | ||||
|  | ||||
| # def test_albert_dataset(args): | ||||
| #     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) | ||||
| #     # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) | ||||
| @@ -81,34 +81,27 @@ def test_indexed_dataset_get(args): | ||||
|  | ||||
| def main(): | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--data', type=str, help='prefix to data files') | ||||
|     parser.add_argument('--dataset-impl', type=str, default='infer', | ||||
|                         choices=['lazy', 'cached', 'mmap', 'infer']) | ||||
|     parser.add_argument('--count', type=int, default=10, | ||||
|                         help='Number of samples/documents to print') | ||||
|     parser.add_argument("--data", type=str, help="prefix to data files") | ||||
|     parser.add_argument("--dataset-impl", type=str, default="infer", choices=["lazy", "cached", "mmap", "infer"]) | ||||
|     parser.add_argument("--count", type=int, default=10, help="Number of samples/documents to print") | ||||
|  | ||||
|     group = parser.add_argument_group(title='tokenizer') | ||||
|     group.add_argument('--tokenizer-type', type=str, required=True, | ||||
|                        choices=['BertWordPieceLowerCase', | ||||
|                                 'GPT2BPETokenizer'], | ||||
|                        help='What type of tokenizer to use.') | ||||
|     group.add_argument('--vocab-file', type=str, default=None, | ||||
|                        help='Path to the vocab file') | ||||
|     group.add_argument('--merge-file', type=str, default=None, | ||||
|                        help='Path to the BPE merge file (if necessary).') | ||||
|     group = parser.add_argument_group(title="tokenizer") | ||||
|     group.add_argument( | ||||
|         "--tokenizer-type", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         choices=["BertWordPieceLowerCase", "GPT2BPETokenizer"], | ||||
|         help="What type of tokenizer to use.", | ||||
|     ) | ||||
|     group.add_argument("--vocab-file", type=str, default=None, help="Path to the vocab file") | ||||
|     group.add_argument("--merge-file", type=str, default=None, help="Path to the BPE merge file (if necessary).") | ||||
|  | ||||
|     parser.add_argument('--epochs', type=int, default=5, | ||||
|                         help='Number of epochs to plan for') | ||||
|     parser.add_argument('--max-num-samples', type=int, default=None, | ||||
|                         help='Maximum number of samples to plan for') | ||||
|     parser.add_argument('--masked-lm-prob', type=float, default=0.15, | ||||
|                         help='probability of masking tokens') | ||||
|     parser.add_argument('--seq-length', type=int, default=512, | ||||
|                         help='maximum sequence length') | ||||
|     parser.add_argument('--short-seq-prob', type=float, default=0.1, | ||||
|                         help='probability of creating a short sequence') | ||||
|     parser.add_argument('--seed', type=int, default=1234, | ||||
|                         help='random seed') | ||||
|     parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to plan for") | ||||
|     parser.add_argument("--max-num-samples", type=int, default=None, help="Maximum number of samples to plan for") | ||||
|     parser.add_argument("--masked-lm-prob", type=float, default=0.15, help="probability of masking tokens") | ||||
|     parser.add_argument("--seq-length", type=int, default=512, help="maximum sequence length") | ||||
|     parser.add_argument("--short-seq-prob", type=float, default=0.1, help="probability of creating a short sequence") | ||||
|     parser.add_argument("--seed", type=int, default=1234, help="random seed") | ||||
|     args = parser.parse_args() | ||||
|     args.rank = 0 | ||||
|     args.make_vocab_size_divisible_by = 128 | ||||
| @@ -117,7 +110,7 @@ def main(): | ||||
|     if args.dataset_impl == "infer": | ||||
|         args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) | ||||
|  | ||||
| #    test_albert_dataset(args) | ||||
|     #    test_albert_dataset(args) | ||||
|     test_indexed_dataset_get(args) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -1,8 +1,7 @@ | ||||
| import torch | ||||
|  | ||||
|  | ||||
| class DummyDataloader(): | ||||
|  | ||||
| class DummyDataloader: | ||||
|     def __init__(self, batch_size, vocab_size, seq_length): | ||||
|         self.batch_size = batch_size | ||||
|         self.vocab_size = vocab_size | ||||
| @@ -10,30 +9,44 @@ class DummyDataloader(): | ||||
|         self.step = 0 | ||||
|  | ||||
|     def generate(self): | ||||
|         tokens = torch.randint(low=0, high=self.vocab_size, size=( | ||||
|             self.batch_size, | ||||
|             self.seq_length, | ||||
|         )) | ||||
|         types = torch.randint(low=0, high=3, size=( | ||||
|             self.batch_size, | ||||
|             self.seq_length, | ||||
|         )) | ||||
|         tokens = torch.randint( | ||||
|             low=0, | ||||
|             high=self.vocab_size, | ||||
|             size=( | ||||
|                 self.batch_size, | ||||
|                 self.seq_length, | ||||
|             ), | ||||
|         ) | ||||
|         types = torch.randint( | ||||
|             low=0, | ||||
|             high=3, | ||||
|             size=( | ||||
|                 self.batch_size, | ||||
|                 self.seq_length, | ||||
|             ), | ||||
|         ) | ||||
|         sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,)) | ||||
|         loss_mask = torch.randint(low=0, high=2, size=( | ||||
|             self.batch_size, | ||||
|             self.seq_length, | ||||
|         )) | ||||
|         loss_mask = torch.randint( | ||||
|             low=0, | ||||
|             high=2, | ||||
|             size=( | ||||
|                 self.batch_size, | ||||
|                 self.seq_length, | ||||
|             ), | ||||
|         ) | ||||
|         lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length)) | ||||
|         padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length)) | ||||
|         return dict(text=tokens, | ||||
|                     types=types, | ||||
|                     is_random=sentence_order, | ||||
|                     loss_mask=loss_mask, | ||||
|                     labels=lm_labels, | ||||
|                     padding_mask=padding_mask) | ||||
|         return dict( | ||||
|             text=tokens, | ||||
|             types=types, | ||||
|             is_random=sentence_order, | ||||
|             loss_mask=loss_mask, | ||||
|             labels=lm_labels, | ||||
|             padding_mask=padding_mask, | ||||
|         ) | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return self | ||||
|  | ||||
|     def __next__(self): | ||||
|         return self.generate() | ||||
|         return self.generate() | ||||
|   | ||||
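The DummyDataloader above only changes call formatting; to make the expected batch layout concrete, here is a self-contained stand-in that mirrors the fields produced by generate() and prints the tensor shapes a consumer would see. The batch, vocab, and sequence sizes are made up for the example.

import torch

def generate(batch_size=4, vocab_size=30522, seq_length=128):
    # Same field names and shapes as DummyDataloader.generate() in the hunk above.
    return dict(
        text=torch.randint(0, vocab_size, (batch_size, seq_length)),
        types=torch.randint(0, 3, (batch_size, seq_length)),
        is_random=torch.randint(0, 2, (batch_size,)),
        loss_mask=torch.randint(0, 2, (batch_size, seq_length)),
        labels=torch.randint(0, vocab_size, (batch_size, seq_length)),
        padding_mask=torch.randint(0, 2, (batch_size, seq_length)),
    )

print({k: tuple(v.shape) for k, v in generate().items()})
# text/types/loss_mask/labels/padding_mask -> (4, 128); is_random -> (4,)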
| @@ -16,7 +16,6 @@ | ||||
|  | ||||
| from .tokenizer import build_tokenizer | ||||
|  | ||||
|  | ||||
| _TOKENIZER = None | ||||
| _PADDED_VOCAB_SIZE = -1 | ||||
|  | ||||
|   | ||||
| @@ -15,13 +15,12 @@ | ||||
|  | ||||
| """Tokenization classes.""" | ||||
|  | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| from __future__ import absolute_import, division, print_function | ||||
|  | ||||
| import collections | ||||
| import re | ||||
| import unicodedata | ||||
|  | ||||
| import six | ||||
|  | ||||
|  | ||||
| @@ -43,14 +42,13 @@ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): | ||||
|     model_name = m.group(1) | ||||
|  | ||||
|     lower_models = [ | ||||
|         "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", | ||||
|         "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" | ||||
|         "uncased_L-24_H-1024_A-16", | ||||
|         "uncased_L-12_H-768_A-12", | ||||
|         "multilingual_L-12_H-768_A-12", | ||||
|         "chinese_L-12_H-768_A-12", | ||||
|     ] | ||||
|  | ||||
|     cased_models = [ | ||||
|         "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", | ||||
|         "multi_cased_L-12_H-768_A-12" | ||||
|     ] | ||||
|     cased_models = ["cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", "multi_cased_L-12_H-768_A-12"] | ||||
|  | ||||
|     is_bad_config = False | ||||
|     if model_name in lower_models and not do_lower_case: | ||||
| @@ -71,8 +69,8 @@ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): | ||||
|             "However, `%s` seems to be a %s model, so you " | ||||
|             "should pass in `--do_lower_case=%s` so that the fine-tuning matches " | ||||
|             "how the model was pre-training. If this error is wrong, please " | ||||
|             "just comment out this check." % (actual_flag, init_checkpoint, | ||||
|                                               model_name, case_name, opposite_flag)) | ||||
|             "just comment out this check." % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def convert_to_unicode(text): | ||||
| @@ -183,27 +181,27 @@ class FullTokenizer(object): | ||||
|  | ||||
|     @staticmethod | ||||
|     def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): | ||||
|         """ Converts a sequence of tokens (string) in a single string. """ | ||||
|         """Converts a sequence of tokens (string) in a single string.""" | ||||
|  | ||||
|         def clean_up_tokenization(out_string): | ||||
|             """ Clean up a list of simple English tokenization artifacts | ||||
|             """Clean up a list of simple English tokenization artifacts | ||||
|             like spaces before punctuations and abbreviated forms. | ||||
|             """ | ||||
|             out_string = ( | ||||
|                 out_string.replace(" .", ".") | ||||
|                     .replace(" ?", "?") | ||||
|                     .replace(" !", "!") | ||||
|                     .replace(" ,", ",") | ||||
|                     .replace(" ' ", "'") | ||||
|                     .replace(" n't", "n't") | ||||
|                     .replace(" 'm", "'m") | ||||
|                     .replace(" 's", "'s") | ||||
|                     .replace(" 've", "'ve") | ||||
|                     .replace(" 're", "'re") | ||||
|                 .replace(" ?", "?") | ||||
|                 .replace(" !", "!") | ||||
|                 .replace(" ,", ",") | ||||
|                 .replace(" ' ", "'") | ||||
|                 .replace(" n't", "n't") | ||||
|                 .replace(" 'm", "'m") | ||||
|                 .replace(" 's", "'s") | ||||
|                 .replace(" 've", "'ve") | ||||
|                 .replace(" 're", "'re") | ||||
|             ) | ||||
|             return out_string | ||||
|  | ||||
|         text = ' '.join(tokens).replace(' ##', '').strip() | ||||
|         text = " ".join(tokens).replace(" ##", "").strip() | ||||
|         if clean_up_tokenization_spaces: | ||||
|             clean_text = clean_up_tokenization(text) | ||||
|             return clean_text | ||||
| @@ -303,14 +301,16 @@ class BasicTokenizer(object): | ||||
|         # as is Japanese Hiragana and Katakana. Those alphabets are used to write | ||||
|         # space-separated words, so they are not treated specially and handled | ||||
|         # like the all of the other languages. | ||||
|         if ((cp >= 0x4E00 and cp <= 0x9FFF) or  # | ||||
|             (cp >= 0x3400 and cp <= 0x4DBF) or  # | ||||
|             (cp >= 0x20000 and cp <= 0x2A6DF) or  # | ||||
|             (cp >= 0x2A700 and cp <= 0x2B73F) or  # | ||||
|             (cp >= 0x2B740 and cp <= 0x2B81F) or  # | ||||
|             (cp >= 0x2B820 and cp <= 0x2CEAF) or | ||||
|             (cp >= 0xF900 and cp <= 0xFAFF) or  # | ||||
|                 (cp >= 0x2F800 and cp <= 0x2FA1F)):  # | ||||
|         if ( | ||||
|             (cp >= 0x4E00 and cp <= 0x9FFF) | ||||
|             or (cp >= 0x3400 and cp <= 0x4DBF)  # | ||||
|             or (cp >= 0x20000 and cp <= 0x2A6DF)  # | ||||
|             or (cp >= 0x2A700 and cp <= 0x2B73F)  # | ||||
|             or (cp >= 0x2B740 and cp <= 0x2B81F)  # | ||||
|             or (cp >= 0x2B820 and cp <= 0x2CEAF)  # | ||||
|             or (cp >= 0xF900 and cp <= 0xFAFF) | ||||
|             or (cp >= 0x2F800 and cp <= 0x2FA1F)  # | ||||
|         ):  # | ||||
|             return True | ||||
|  | ||||
|         return False | ||||
| @@ -320,7 +320,7 @@ class BasicTokenizer(object): | ||||
|         output = [] | ||||
|         for char in text: | ||||
|             cp = ord(char) | ||||
|             if cp == 0 or cp == 0xfffd or _is_control(char): | ||||
|             if cp == 0 or cp == 0xFFFD or _is_control(char): | ||||
|                 continue | ||||
|             if _is_whitespace(char): | ||||
|                 output.append(" ") | ||||
| @@ -422,8 +422,7 @@ def _is_punctuation(char): | ||||
|     # Characters such as "^", "$", and "`" are not in the Unicode | ||||
|     # Punctuation class but we treat them as punctuation anyways, for | ||||
|     # consistency. | ||||
|     if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or | ||||
|             (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): | ||||
|     if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): | ||||
|         return True | ||||
|     cat = unicodedata.category(char) | ||||
|     if cat.startswith("P"): | ||||
|   | ||||
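The convert_tokens_to_string hunk above is reformatting only; as a quick illustration of what that cleanup chain does, here is a minimal sketch using just the replacements visible in the diff. The sample tokens are made up.

def detokenize(tokens):
    # Join WordPiece tokens, then undo spaces before punctuation and contractions,
    # following the replacement chain shown in the hunk above.
    text = " ".join(tokens).replace(" ##", "").strip()
    for before, after in [(" .", "."), (" ?", "?"), (" !", "!"), (" ,", ","),
                          (" ' ", "'"), (" n't", "n't"), (" 'm", "'m"),
                          (" 's", "'s"), (" 've", "'ve"), (" 're", "'re")]:
        text = text.replace(before, after)
    return text

print(detokenize(["he", "##llo", ",", "world", "!"]))  # hello, world!
print(detokenize(["it", "'s", "do", "n't"]))           # it's don't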
| @@ -25,16 +25,15 @@ from .bert_tokenization import FullTokenizer as FullBertTokenizer | ||||
| def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): | ||||
|     """Initialize tokenizer.""" | ||||
|     if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: | ||||
|         print('> building {} tokenizer ...'.format(tokenizer_type), flush=True) | ||||
|         print("> building {} tokenizer ...".format(tokenizer_type), flush=True) | ||||
|  | ||||
|     # Select and instantiate the tokenizer. | ||||
|     if tokenizer_type == 'BertWordPieceLowerCase': | ||||
|     if tokenizer_type == "BertWordPieceLowerCase": | ||||
|         tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids) | ||||
|     elif tokenizer_type == 'BertWordPieceCase': | ||||
|     elif tokenizer_type == "BertWordPieceCase": | ||||
|         tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids) | ||||
|     else: | ||||
|         raise NotImplementedError('{} tokenizer is not ' | ||||
|                                   'implemented.'.format(tokenizer_type)) | ||||
|         raise NotImplementedError("{} tokenizer is not " "implemented.".format(tokenizer_type)) | ||||
|  | ||||
|     # Add vocab size. | ||||
|     padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size) | ||||
| @@ -55,9 +54,11 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): | ||||
|     while (after % multiple) != 0: | ||||
|         after += 1 | ||||
|     if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: | ||||
|         print(' > padded vocab (size: {}) with {} dummy tokens ' | ||||
|               '(new size: {})'.format(orig_vocab_size, after - orig_vocab_size, after), | ||||
|               flush=True) | ||||
|         print( | ||||
|             " > padded vocab (size: {}) with {} dummy tokens " | ||||
|             "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after), | ||||
|             flush=True, | ||||
|         ) | ||||
|     return after | ||||
|  | ||||
|  | ||||
| @@ -77,46 +78,38 @@ class AbstractTokenizer(ABC): | ||||
|     @abstractmethod | ||||
|     def vocab(self): | ||||
|         """Dictionary from vocab text token to id token.""" | ||||
|         pass | ||||
|  | ||||
|     @property | ||||
|     @abstractmethod | ||||
|     def inv_vocab(self): | ||||
|         """Dictionary from vocab id token to text token.""" | ||||
|         pass | ||||
|  | ||||
|     @abstractmethod | ||||
|     def tokenize(self, text): | ||||
|         pass | ||||
|  | ||||
|     def detokenize(self, token_ids): | ||||
|         raise NotImplementedError('detokenizer is not implemented for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|         raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def cls(self): | ||||
|         raise NotImplementedError('CLS is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|         raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def sep(self): | ||||
|         raise NotImplementedError('SEP is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|         raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def pad(self): | ||||
|         raise NotImplementedError('PAD is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|         raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def eod(self): | ||||
|         raise NotImplementedError('EOD is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|         raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name)) | ||||
|  | ||||
|     @property | ||||
|     def mask(self): | ||||
|         raise NotImplementedError('MASK is not provided for {} ' | ||||
|                                   'tokenizer'.format(self.name)) | ||||
|         raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name)) | ||||
|  | ||||
|  | ||||
| class _BertWordPieceTokenizer(AbstractTokenizer): | ||||
| @@ -124,24 +117,24 @@ class _BertWordPieceTokenizer(AbstractTokenizer): | ||||
|  | ||||
|     def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): | ||||
|         if lower_case: | ||||
|             name = 'BERT Lower Case' | ||||
|             name = "BERT Lower Case" | ||||
|         else: | ||||
|             name = 'BERT Upper Case' | ||||
|             name = "BERT Upper Case" | ||||
|         super().__init__(name) | ||||
|         self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) | ||||
|         self.cls_id = self.tokenizer.vocab['[CLS]'] | ||||
|         self.sep_id = self.tokenizer.vocab['[SEP]'] | ||||
|         self.pad_id = self.tokenizer.vocab['[PAD]'] | ||||
|         self.mask_id = self.tokenizer.vocab['[MASK]'] | ||||
|         self.cls_id = self.tokenizer.vocab["[CLS]"] | ||||
|         self.sep_id = self.tokenizer.vocab["[SEP]"] | ||||
|         self.pad_id = self.tokenizer.vocab["[PAD]"] | ||||
|         self.mask_id = self.tokenizer.vocab["[MASK]"] | ||||
|         self._additional_special_tokens = [] | ||||
|  | ||||
|         # (dsachan) Add BOS and EOS tokens | ||||
|         SPECIAL_TOKENS = {'eos_token': '[EOS]', 'bos_token': '[BOS]'} | ||||
|         self._bos_token = '[BOS]' | ||||
|         SPECIAL_TOKENS = {"eos_token": "[EOS]", "bos_token": "[BOS]"} | ||||
|         self._bos_token = "[BOS]" | ||||
|         self.add_token(self._bos_token) | ||||
|         self._bos_token_id = self.vocab.get(self._bos_token) | ||||
|  | ||||
|         self._eos_token = '[EOS]' | ||||
|         self._eos_token = "[EOS]" | ||||
|         self.add_token(self._eos_token) | ||||
|         self._eos_token_id = self.vocab.get(self._eos_token) | ||||
|  | ||||
| @@ -185,7 +178,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer): | ||||
|  | ||||
|     def decode_token_ids(self, token_ids): | ||||
|         tokens = self.tokenizer.convert_ids_to_tokens(token_ids) | ||||
|         exclude_list = ['[PAD]', '[CLS]'] | ||||
|         exclude_list = ["[PAD]", "[CLS]"] | ||||
|         non_pads = [t for t in tokens if t not in exclude_list] | ||||
|  | ||||
|         result = "" | ||||
| @@ -215,32 +208,32 @@ class _BertWordPieceTokenizer(AbstractTokenizer): | ||||
|  | ||||
|     @property | ||||
|     def bos_token(self): | ||||
|         """ Beginning of sentence token id """ | ||||
|         """Beginning of sentence token id""" | ||||
|         return self._bos_token | ||||
|  | ||||
|     @property | ||||
|     def eos_token(self): | ||||
|         """ End of sentence token id """ | ||||
|         """End of sentence token id""" | ||||
|         return self._eos_token | ||||
|  | ||||
|     @property | ||||
|     def additional_special_tokens(self): | ||||
|         """ All the additional special tokens you may want to use (list of strings).""" | ||||
|         """All the additional special tokens you may want to use (list of strings).""" | ||||
|         return self._additional_special_tokens | ||||
|  | ||||
|     @property | ||||
|     def bos_token_id(self): | ||||
|         """ Id of the beginning of sentence token in the vocabulary.""" | ||||
|         """Id of the beginning of sentence token in the vocabulary.""" | ||||
|         return self._bos_token_id | ||||
|  | ||||
|     @property | ||||
|     def eos_token_id(self): | ||||
|         """ Id of the end of sentence token in the vocabulary.""" | ||||
|         """Id of the end of sentence token in the vocabulary.""" | ||||
|         return self._eos_token_id | ||||
|  | ||||
|     @property | ||||
|     def additional_special_tokens_ids(self): | ||||
|         """ Ids of all the additional special tokens in the vocabulary (list of integers).""" | ||||
|         """Ids of all the additional special tokens in the vocabulary (list of integers).""" | ||||
|         return [self.vocab.get(token) for token in self._additional_special_tokens] | ||||
|  | ||||
|     @additional_special_tokens.setter | ||||
|   | ||||
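Finally, the _vocab_size_with_padding hunk above only re-wraps the log message, but the padding rule it reports is simple to verify by hand. Below is a simplified sketch of that rule, ignoring any parallel-size factor that may be folded into the multiple; the vocab size is just an example.

def pad_vocab_size(orig_vocab_size, make_vocab_size_divisible_by=128):
    # Round the vocabulary up to the next multiple so partitions divide evenly.
    after = orig_vocab_size
    while after % make_vocab_size_divisible_by != 0:
        after += 1
    return after

print(pad_vocab_size(30522))  # 30592, i.e. 70 dummy tokens appended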