from colossalai.logging import get_dist_logger

from .bert_dataset import BertDataset
from .blendable_dataset import BlendableDataset
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_

DSET_TYPE_BERT = "standard_bert"
DSET_TYPE_ICT = "ict"
DSET_TYPE_T5 = "t5"

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]


def _build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
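    """Build train/valid/test datasets for a single data prefix.

    splits_string is the comma-separated split-weight string consumed by
    get_train_valid_test_split_ (e.g. "949,50,1" in the Megatron-LM
    convention). Only dataset_type == DSET_TYPE_BERT is currently implemented.
    """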
    if dataset_type not in DSET_TYPES:
        raise ValueError(f"Invalid dataset_type: {dataset_type}")

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx.
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    logger = get_dist_logger()

    # Print stats about the splits.
    logger.info("\n > dataset split:", ranks=[0])

    def print_split_stats(name, index):
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        logger.info(
            "\n    {}:".format(name)
            + "\n     document indices in [{}, {}) total of {} documents".format(
                splits[index], splits[index + 1], splits[index + 1] - splits[index]
            )
            + "\n     sentence indices in [{}, {}) total of {} sentences".format(
                start_index, end_index, end_index - start_index
            ),
            ranks=[0],
        )

    print_split_stats("train", 0)
    print_split_stats("validation", 1)
    print_split_stats("test", 2)

    def build_dataset(index, name):
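        # Build the `name` split by temporarily narrowing the indexed
        # dataset's doc-idx view to this split's document range; the
        # original view is restored before returning.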
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx.
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
            )

            if dataset_type != DSET_TYPE_BERT:
                raise NotImplementedError("Only the BERT dataset type is supported")
            dataset = BertDataset(
                indexed_dataset=indexed_dataset,
                masked_lm_prob=masked_lm_prob,
                short_seq_prob=short_seq_prob,
                binary_head=binary_head,
                **kwargs,
            )

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, "train")
    valid_dataset = build_dataset(1, "valid")
    test_dataset = build_dataset(2, "test")

    return (train_dataset, valid_dataset, test_dataset)


def build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
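    """Build train/valid/test datasets from one or more data prefixes.

    A single prefix delegates to _build_train_valid_test_datasets; multiple
    prefixes are built individually and combined with BlendableDataset using
    the weights parsed by get_datasets_weights_and_num_samples.
    """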
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            dataset_type=dataset_type,
        )

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output
| 
 | |
|     # Build individual datasets.
 | |
|     train_datasets = []
 | |
|     valid_datasets = []
 | |
|     test_datasets = []
 | |
|     for i in range(len(prefixes)):
 | |
|         train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
 | |
|             prefixes[i],
 | |
|             data_impl,
 | |
|             splits_string,
 | |
|             datasets_train_valid_test_num_samples[i],
 | |
|             max_seq_length,
 | |
|             masked_lm_prob,
 | |
|             short_seq_prob,
 | |
|             seed,
 | |
|             skip_warmup,
 | |
|             binary_head,
 | |
|             dataset_type=dataset_type,
 | |
|         )
 | |
|         if train_ds:
 | |
|             train_datasets.append(train_ds)
 | |
|         if valid_ds:
 | |
|             valid_datasets.append(valid_ds)
 | |
|         if test_ds:
 | |
|             test_datasets.append(test_ds)
 | |
| 
 | |
|         # Blend.
 | |
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
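

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): how the
# builder above might be called. The data prefix and sample counts below are
# hypothetical placeholders; splits_string follows the comma-separated
# Megatron-LM split-weight convention.
# ---------------------------------------------------------------------------
# train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#     data_prefix=["/path/to/my-bert_text_sentence"],  # hypothetical path
#     data_impl="mmap",
#     splits_string="949,50,1",  # 94.9% train / 5% valid / 0.1% test
#     train_valid_test_num_samples=[1_000_000, 10_000, 1_000],
#     max_seq_length=512,
#     masked_lm_prob=0.15,
#     short_seq_prob=0.1,
#     seed=1234,
#     skip_warmup=True,
#     binary_head=True,
#     dataset_type=DSET_TYPE_BERT,
# )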