Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 13:30:19 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
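The hunks below are the mechanical result of running the updated hooks over the data-pipeline files: string literals are normalized to double quotes, long signatures and calls are exploded one argument per line with trailing commas, and imports are merged and reordered. The actual hook configuration is not part of this excerpt; as a rough, hypothetical sketch (repository revisions and the exclude pattern are placeholders, not taken from this commit), a .pre-commit-config.yaml for this kind of setup might look like:

    repos:
      - repo: https://github.com/psf/black
        rev: 23.9.1            # placeholder revision
        hooks:
          - id: black
      - repo: https://github.com/PyCQA/isort
        rev: 5.12.0            # placeholder revision
        hooks:
          - id: isort
      - repo: https://github.com/pre-commit/mirrors-clang-format
        rev: v16.0.6           # placeholder revision
        hooks:
          - id: clang-format
            exclude: \.cu$     # skip CUDA sources, per "ignore cuda for clang-format"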
@@ -41,10 +41,19 @@ except:
|
||||
|
||||
|
||||
class BertDataset(Dataset):
|
||||
|
||||
def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
|
||||
short_seq_prob, seed, binary_head):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
indexed_dataset,
|
||||
data_prefix,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
masked_lm_prob,
|
||||
max_seq_length,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
binary_head,
|
||||
):
|
||||
# Params to store.
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
@@ -61,11 +70,12 @@ class BertDataset(Dataset):
|
||||
data_prefix,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
self.max_seq_length - 3, # account for added tokens,
|
||||
self.max_seq_length - 3, # account for added tokens,
|
||||
short_seq_prob,
|
||||
self.seed,
|
||||
self.name,
|
||||
self.binary_head)
|
||||
self.binary_head,
|
||||
)
|
||||
|
||||
# Vocab stuff.
|
||||
tokenizer = get_tokenizer()
|
||||
@@ -89,7 +99,7 @@ class BertDataset(Dataset):
|
||||
return build_training_sample(
|
||||
sample,
|
||||
seq_length,
|
||||
self.max_seq_length, # needed for padding
|
||||
self.max_seq_length, # needed for padding
|
||||
self.vocab_id_list,
|
||||
self.vocab_id_to_token_dict,
|
||||
self.cls_id,
|
||||
@@ -98,37 +108,39 @@ class BertDataset(Dataset):
|
||||
self.pad_id,
|
||||
self.masked_lm_prob,
|
||||
np_rng,
|
||||
self.binary_head)
|
||||
self.binary_head,
|
||||
)
|
||||
|
||||
|
||||
def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob,
|
||||
seed, name, binary_head):
|
||||
def get_samples_mapping_(
|
||||
indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head
|
||||
):
|
||||
logger = get_dist_logger()
|
||||
if not num_epochs:
|
||||
if not max_num_samples:
|
||||
raise ValueError("Need to specify either max_num_samples "
|
||||
"or num_epochs")
|
||||
raise ValueError("Need to specify either max_num_samples " "or num_epochs")
|
||||
num_epochs = np.iinfo(np.int32).max - 1
|
||||
if not max_num_samples:
|
||||
max_num_samples = np.iinfo(np.int64).max - 1
|
||||
|
||||
# Filename of the index mapping
|
||||
indexmap_filename = data_prefix
|
||||
indexmap_filename += '_{}_indexmap'.format(name)
|
||||
indexmap_filename += "_{}_indexmap".format(name)
|
||||
if num_epochs != (np.iinfo(np.int32).max - 1):
|
||||
indexmap_filename += '_{}ep'.format(num_epochs)
|
||||
indexmap_filename += "_{}ep".format(num_epochs)
|
||||
if max_num_samples != (np.iinfo(np.int64).max - 1):
|
||||
indexmap_filename += '_{}mns'.format(max_num_samples)
|
||||
indexmap_filename += '_{}msl'.format(max_seq_length)
|
||||
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
|
||||
indexmap_filename += '_{}s'.format(seed)
|
||||
indexmap_filename += '.npy'
|
||||
indexmap_filename += "_{}mns".format(max_num_samples)
|
||||
indexmap_filename += "_{}msl".format(max_seq_length)
|
||||
indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob)
|
||||
indexmap_filename += "_{}s".format(seed)
|
||||
indexmap_filename += ".npy"
|
||||
|
||||
# Build the indexed mapping if not exist.
|
||||
if torch.distributed.get_rank() == 0 and \
|
||||
not os.path.isfile(indexmap_filename):
|
||||
print(' > WARNING: could not find index map file {}, building '
|
||||
'the indices on rank 0 ...'.format(indexmap_filename))
|
||||
if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
|
||||
print(
|
||||
" > WARNING: could not find index map file {}, building "
|
||||
"the indices on rank 0 ...".format(indexmap_filename)
|
||||
)
|
||||
|
||||
# Make sure the types match the helpers input types.
|
||||
assert indexed_dataset.doc_idx.dtype == np.int64
|
||||
@@ -137,18 +149,27 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl
|
||||
# Build samples mapping
|
||||
verbose = torch.distributed.get_rank() == 0
|
||||
start_time = time.time()
|
||||
logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0])
|
||||
logger.info("\n > building samples index mapping for {} ...".format(name), ranks=[0])
|
||||
# First compile and then import.
|
||||
samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs,
|
||||
max_num_samples, max_seq_length, short_seq_prob, seed, verbose,
|
||||
2 if binary_head else 1)
|
||||
logger.info('\n > done building samples index maping', ranks=[0])
|
||||
samples_mapping = helpers.build_mapping(
|
||||
indexed_dataset.doc_idx,
|
||||
indexed_dataset.sizes,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
max_seq_length,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
verbose,
|
||||
2 if binary_head else 1,
|
||||
)
|
||||
logger.info("\n > done building samples index maping", ranks=[0])
|
||||
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
|
||||
logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0])
|
||||
logger.info("\n > saved the index mapping in {}".format(indexmap_filename), ranks=[0])
|
||||
# Make sure all the ranks have built the mapping
|
||||
logger.info('\n > elapsed time to build and save samples mapping '
|
||||
'(seconds): {:4f}'.format(time.time() - start_time),
|
||||
ranks=[0])
|
||||
logger.info(
|
||||
"\n > elapsed time to build and save samples mapping " "(seconds): {:4f}".format(time.time() - start_time),
|
||||
ranks=[0],
|
||||
)
|
||||
# This should be a barrier but nccl barrier assumes
|
||||
# device_index=rank which is not the case for model
|
||||
# parallel case
|
||||
@@ -156,22 +177,38 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl
|
||||
torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA))
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE))
|
||||
assert counts[0].item() == (torch.distributed.get_world_size() //
|
||||
torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE)))
|
||||
assert counts[0].item() == (
|
||||
torch.distributed.get_world_size()
|
||||
// torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||
)
|
||||
|
||||
# Load indexed dataset.
|
||||
start_time = time.time()
|
||||
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
|
||||
logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) +
|
||||
'\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) +
|
||||
'\n total number of samples: {}'.format(samples_mapping.shape[0]),
|
||||
ranks=[0])
|
||||
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r")
|
||||
logger.info(
|
||||
"\n > loading indexed mapping from {}".format(indexmap_filename)
|
||||
+ "\n loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
|
||||
+ "\n total number of samples: {}".format(samples_mapping.shape[0]),
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
return samples_mapping
|
||||
|
||||
|
||||
def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id,
|
||||
sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head):
|
||||
def build_training_sample(
|
||||
sample,
|
||||
target_seq_length,
|
||||
max_seq_length,
|
||||
vocab_id_list,
|
||||
vocab_id_to_token_dict,
|
||||
cls_id,
|
||||
sep_id,
|
||||
mask_id,
|
||||
pad_id,
|
||||
masked_lm_prob,
|
||||
np_rng,
|
||||
binary_head,
|
||||
):
|
||||
"""Build training sample.
|
||||
|
||||
Arguments:
|
||||
@@ -215,22 +252,30 @@ def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_li
|
||||
|
||||
# Masking.
|
||||
max_predictions_per_seq = masked_lm_prob * max_num_tokens
|
||||
(tokens, masked_positions, masked_labels,
|
||||
_) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id,
|
||||
mask_id, max_predictions_per_seq, np_rng)
|
||||
(tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
|
||||
tokens,
|
||||
vocab_id_list,
|
||||
vocab_id_to_token_dict,
|
||||
masked_lm_prob,
|
||||
cls_id,
|
||||
sep_id,
|
||||
mask_id,
|
||||
max_predictions_per_seq,
|
||||
np_rng,
|
||||
)
|
||||
|
||||
# Padding.
|
||||
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
|
||||
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
||||
masked_labels, pad_id, max_seq_length)
|
||||
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
|
||||
tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
|
||||
)
|
||||
|
||||
train_sample = {
|
||||
'text': tokens_np,
|
||||
'types': tokentypes_np,
|
||||
'labels': labels_np,
|
||||
'is_random': int(is_next_random),
|
||||
'loss_mask': loss_mask_np,
|
||||
'padding_mask': padding_mask_np,
|
||||
'truncated': int(truncated)
|
||||
"text": tokens_np,
|
||||
"types": tokentypes_np,
|
||||
"labels": labels_np,
|
||||
"is_random": int(is_next_random),
|
||||
"loss_mask": loss_mask_np,
|
||||
"padding_mask": padding_mask_np,
|
||||
"truncated": int(truncated),
|
||||
}
|
||||
return train_sample
|
@@ -22,9 +22,7 @@ import torch
class BlendableDataset(torch.utils.data.Dataset):
def __init__(self, datasets, weights):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)

@@ -46,12 +44,16 @@ class BlendableDataset(torch.utils.data.Dataset):
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
from . import helpers
helpers.build_blending_indices(self.dataset_index,
self.dataset_sample_index,
weights, num_datasets, self.size,
torch.distributed.get_rank() == 0)
print('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))
helpers.build_blending_indices(
self.dataset_index,
self.dataset_sample_index,
weights,
num_datasets,
self.size,
torch.distributed.get_rank() == 0,
)
print("> elapsed time for building blendable dataset indices: " "{:.2f} (sec)".format(time.time() - start_time))
def __len__(self):
return self.size
@@ -1,29 +1,34 @@
|
||||
from .blendable_dataset import BlendableDataset
|
||||
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
|
||||
from .bert_dataset import BertDataset
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
DSET_TYPE_BERT = 'standard_bert'
|
||||
DSET_TYPE_ICT = 'ict'
|
||||
DSET_TYPE_T5 = 't5'
|
||||
from .bert_dataset import BertDataset
|
||||
from .blendable_dataset import BlendableDataset
|
||||
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
|
||||
|
||||
DSET_TYPE_BERT = "standard_bert"
|
||||
DSET_TYPE_ICT = "ict"
|
||||
DSET_TYPE_T5 = "t5"
|
||||
|
||||
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
|
||||
|
||||
|
||||
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed, skip_warmup,
|
||||
binary_head,
|
||||
dataset_type='standard_bert'):
|
||||
|
||||
def _build_train_valid_test_datasets(
|
||||
data_prefix,
|
||||
data_impl,
|
||||
splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length,
|
||||
masked_lm_prob,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type="standard_bert",
|
||||
):
|
||||
if dataset_type not in DSET_TYPES:
|
||||
raise ValueError("Invalid dataset_type: ", dataset_type)
|
||||
|
||||
# Indexed dataset.
|
||||
indexed_dataset = get_indexed_dataset_(data_prefix,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
|
||||
|
||||
# Get start and end indices of train/valid/train into doc-idx
|
||||
# Note that doc-idx is designed to be num-docs + 1 so we can
|
||||
@@ -34,22 +39,25 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
logger = get_dist_logger()
|
||||
|
||||
# Print stats about the splits.
|
||||
logger.info('\n > dataset split:', ranks=[0])
|
||||
logger.info("\n > dataset split:", ranks=[0])
|
||||
|
||||
def print_split_stats(name, index):
|
||||
start_index = indexed_dataset.doc_idx[splits[index]]
|
||||
end_index = indexed_dataset.doc_idx[splits[index + 1]]
|
||||
logger.info('\n {}:'.format(name) +
|
||||
'\n document indices in [{}, {}) total of {} documents'.format(
|
||||
splits[index], splits[index + 1],
|
||||
splits[index + 1] - splits[index]) +
|
||||
'\n sentence indices in [{}, {}) total of {} sentences'.format(
|
||||
start_index, end_index,
|
||||
end_index - start_index),
|
||||
ranks=[0])
|
||||
print_split_stats('train', 0)
|
||||
print_split_stats('validation', 1)
|
||||
print_split_stats('test', 2)
|
||||
logger.info(
|
||||
"\n {}:".format(name)
|
||||
+ "\n document indices in [{}, {}) total of {} documents".format(
|
||||
splits[index], splits[index + 1], splits[index + 1] - splits[index]
|
||||
)
|
||||
+ "\n sentence indices in [{}, {}) total of {} sentences".format(
|
||||
start_index, end_index, end_index - start_index
|
||||
),
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
print_split_stats("train", 0)
|
||||
print_split_stats("validation", 1)
|
||||
print_split_stats("test", 2)
|
||||
|
||||
def build_dataset(index, name):
|
||||
dataset = None
|
||||
@@ -80,44 +88,53 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
masked_lm_prob=masked_lm_prob,
|
||||
short_seq_prob=short_seq_prob,
|
||||
binary_head=binary_head,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Set the original pointer so dataset remains the main dataset.
|
||||
indexed_dataset.set_doc_idx(doc_idx_ptr)
|
||||
# Checks.
|
||||
assert indexed_dataset.doc_idx[0] == 0
|
||||
assert indexed_dataset.doc_idx.shape[0] == \
|
||||
(total_num_of_documents + 1)
|
||||
assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
|
||||
return dataset
|
||||
|
||||
train_dataset = build_dataset(0, 'train')
|
||||
valid_dataset = build_dataset(1, 'valid')
|
||||
test_dataset = build_dataset(2, 'test')
|
||||
train_dataset = build_dataset(0, "train")
|
||||
valid_dataset = build_dataset(1, "valid")
|
||||
test_dataset = build_dataset(2, "test")
|
||||
|
||||
return (train_dataset, valid_dataset, test_dataset)
|
||||
|
||||
|
||||
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed, skip_warmup,
|
||||
binary_head,
|
||||
dataset_type='standard_bert'):
|
||||
|
||||
def build_train_valid_test_datasets(
|
||||
data_prefix,
|
||||
data_impl,
|
||||
splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length,
|
||||
masked_lm_prob,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type="standard_bert",
|
||||
):
|
||||
if len(data_prefix) == 1:
|
||||
return _build_train_valid_test_datasets(data_prefix[0],
|
||||
data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type=dataset_type)
|
||||
return _build_train_valid_test_datasets(
|
||||
data_prefix[0],
|
||||
data_impl,
|
||||
splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length,
|
||||
masked_lm_prob,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type=dataset_type,
|
||||
)
|
||||
# Blending dataset.
|
||||
# Parse the values.
|
||||
output = get_datasets_weights_and_num_samples(data_prefix,
|
||||
train_valid_test_num_samples)
|
||||
output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
|
||||
prefixes, weights, datasets_train_valid_test_num_samples = output
|
||||
|
||||
# Build individual datasets.
|
||||
@@ -126,10 +143,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
test_datasets = []
|
||||
for i in range(len(prefixes)):
|
||||
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
|
||||
prefixes[i], data_impl, splits_string,
|
||||
prefixes[i],
|
||||
data_impl,
|
||||
splits_string,
|
||||
datasets_train_valid_test_num_samples[i],
|
||||
max_seq_length, masked_lm_prob, short_seq_prob,
|
||||
seed, skip_warmup, binary_head, dataset_type=dataset_type)
|
||||
max_seq_length,
|
||||
masked_lm_prob,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type=dataset_type,
|
||||
)
|
||||
if train_ds:
|
||||
train_datasets.append(train_ds)
|
||||
if valid_ds:
|
||||
@@ -148,5 +173,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
if test_datasets:
|
||||
blending_test_dataset = BlendableDataset(test_datasets, weights)
|
||||
|
||||
return (blending_train_dataset, blending_valid_dataset,
|
||||
blending_test_dataset)
|
||||
return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
|
||||
|
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
"""Dataloaders."""
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
@@ -22,61 +21,60 @@ from colossalai.legacy.context import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
|
||||
def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0):
|
||||
def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type="single", num_workers=0):
|
||||
"""Build dataloader given an input dataset."""
|
||||
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
# Megatron sampler
|
||||
if dataloader_type == 'single':
|
||||
batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=micro_batch_size,
|
||||
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
|
||||
elif dataloader_type == 'cyclic':
|
||||
batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=micro_batch_size,
|
||||
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
|
||||
if dataloader_type == "single":
|
||||
batch_sampler = MegatronPretrainingSampler(
|
||||
total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=micro_batch_size,
|
||||
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_parallel_size=gpc.get_world_size(ParallelMode.DATA),
|
||||
)
|
||||
elif dataloader_type == "cyclic":
|
||||
batch_sampler = MegatronPretrainingRandomSampler(
|
||||
total_samples=len(dataset),
|
||||
consumed_samples=consumed_samples,
|
||||
micro_batch_size=micro_batch_size,
|
||||
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_parallel_size=gpc.get_world_size(ParallelMode.DATA),
|
||||
)
|
||||
else:
|
||||
raise Exception('{} dataloader type is not supported.'.format(dataloader_type))
|
||||
raise Exception("{} dataloader type is not supported.".format(dataloader_type))
|
||||
|
||||
# Torch dataloader.
|
||||
return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True)
|
||||
|
||||
|
||||
class MegatronPretrainingSampler:
|
||||
|
||||
def __init__(self,
|
||||
total_samples,
|
||||
consumed_samples,
|
||||
micro_batch_size,
|
||||
data_parallel_rank,
|
||||
data_parallel_size,
|
||||
drop_last=True):
|
||||
def __init__(
|
||||
self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True
|
||||
):
|
||||
# Keep a copy of input params for later use.
|
||||
self.total_samples = total_samples
|
||||
self.consumed_samples = consumed_samples
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
self.micro_batch_times_data_parallel_size = \
|
||||
self.micro_batch_size * data_parallel_size
|
||||
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
|
||||
self.drop_last = drop_last
|
||||
|
||||
# Sanity checks.
|
||||
assert self.total_samples > 0, \
|
||||
'no sample to consume: {}'.format(self.total_samples)
|
||||
assert self.consumed_samples < self.total_samples, \
|
||||
'no samples left to consume: {}, {}'.format(self.consumed_samples,
|
||||
self.total_samples)
|
||||
assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
|
||||
assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format(
|
||||
self.consumed_samples, self.total_samples
|
||||
)
|
||||
assert self.micro_batch_size > 0
|
||||
assert data_parallel_size > 0
|
||||
assert self.data_parallel_rank < data_parallel_size, \
|
||||
'data_parallel_rank should be smaller than data size: {}, ' \
|
||||
'{}'.format(self.data_parallel_rank, data_parallel_size)
|
||||
assert (
|
||||
self.data_parallel_rank < data_parallel_size
|
||||
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
|
||||
self.data_parallel_rank, data_parallel_size
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
@@ -103,7 +101,6 @@ class MegatronPretrainingSampler:
|
||||
|
||||
|
||||
class MegatronPretrainingRandomSampler:
|
||||
|
||||
def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
|
||||
# Keep a copy of input params for later use.
|
||||
self.total_samples = total_samples
|
||||
@@ -111,19 +108,18 @@ class MegatronPretrainingRandomSampler:
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
self.data_parallel_size = data_parallel_size
|
||||
self.micro_batch_times_data_parallel_size = \
|
||||
self.micro_batch_size * data_parallel_size
|
||||
self.last_batch_size = \
|
||||
self.total_samples % self.micro_batch_times_data_parallel_size
|
||||
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
|
||||
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
|
||||
|
||||
# Sanity checks.
|
||||
assert self.total_samples > 0, \
|
||||
'no sample to consume: {}'.format(self.total_samples)
|
||||
assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
|
||||
assert self.micro_batch_size > 0
|
||||
assert data_parallel_size > 0
|
||||
assert self.data_parallel_rank < data_parallel_size, \
|
||||
'data_parallel_rank should be smaller than data size: {}, ' \
|
||||
'{}'.format(self.data_parallel_rank, data_parallel_size)
|
||||
assert (
|
||||
self.data_parallel_rank < data_parallel_size
|
||||
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
|
||||
self.data_parallel_rank, data_parallel_size
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.total_samples
|
||||
@@ -135,8 +131,7 @@ class MegatronPretrainingRandomSampler:
|
||||
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
|
||||
|
||||
# data sharding and random sampling
|
||||
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
|
||||
* self.micro_batch_size
|
||||
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
|
||||
bucket_offset = current_epoch_samples // self.data_parallel_size
|
||||
start_idx = self.data_parallel_rank * bucket_size
|
||||
|
||||
|
@@ -18,32 +18,33 @@
|
||||
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
|
||||
# with some modifications.
|
||||
|
||||
import collections
|
||||
import math
|
||||
import time
|
||||
import collections
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
import numpy as np
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .blendable_dataset import BlendableDataset
|
||||
from .indexed_dataset import make_dataset as make_indexed_dataset
|
||||
|
||||
DSET_TYPE_STD = 'standard_bert'
|
||||
DSET_TYPE_ICT = 'ict'
|
||||
DSET_TYPE_STD = "standard_bert"
|
||||
DSET_TYPE_ICT = "ict"
|
||||
|
||||
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
|
||||
|
||||
|
||||
def get_datasets_weights_and_num_samples(data_prefix,
|
||||
train_valid_test_num_samples):
|
||||
|
||||
def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
|
||||
# The data prefix should be in the format of:
|
||||
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
|
||||
assert len(data_prefix) % 2 == 0
|
||||
num_datasets = len(data_prefix) // 2
|
||||
weights = [0]*num_datasets
|
||||
prefixes = [0]*num_datasets
|
||||
weights = [0] * num_datasets
|
||||
prefixes = [0] * num_datasets
|
||||
for i in range(num_datasets):
|
||||
weights[i] = float(data_prefix[2*i])
|
||||
prefixes[i] = (data_prefix[2*i+1]).strip()
|
||||
weights[i] = float(data_prefix[2 * i])
|
||||
prefixes[i] = (data_prefix[2 * i + 1]).strip()
|
||||
# Normalize weights
|
||||
weight_sum = 0.0
|
||||
for weight in weights:
|
||||
@@ -57,8 +58,8 @@ def get_datasets_weights_and_num_samples(data_prefix,
|
||||
datasets_train_valid_test_num_samples = []
|
||||
for weight in weights:
|
||||
datasets_train_valid_test_num_samples.append(
|
||||
[int(math.ceil(val * weight * 1.005))
|
||||
for val in train_valid_test_num_samples])
|
||||
[int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
|
||||
)
|
||||
|
||||
return prefixes, weights, datasets_train_valid_test_num_samples
|
||||
|
||||
@@ -68,11 +69,13 @@ def compile_helper():
|
||||
is invoked on a single process."""
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
path = os.path.abspath(os.path.dirname(__file__))
|
||||
ret = subprocess.run(['make', '-C', path])
|
||||
ret = subprocess.run(["make", "-C", path])
|
||||
if ret.returncode != 0:
|
||||
print("Making C++ dataset helpers module failed, exiting.")
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@@ -82,7 +85,7 @@ def get_a_and_b_segments(sample, np_rng):
|
||||
# Number of sentences in the sample.
|
||||
n_sentences = len(sample)
|
||||
# Make sure we always have two sentences.
|
||||
assert n_sentences > 1, 'make sure each sample has at least two sentences.'
|
||||
assert n_sentences > 1, "make sure each sample has at least two sentences."
|
||||
|
||||
# First part:
|
||||
# `a_end` is how many sentences go into the `A`.
|
||||
@@ -110,7 +113,7 @@ def get_a_and_b_segments(sample, np_rng):
|
||||
|
||||
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
|
||||
"""Truncates a pair of sequences to a maximum sequence length."""
|
||||
#print(len_a, len_b, max_num_tokens)
|
||||
# print(len_a, len_b, max_num_tokens)
|
||||
assert len_a > 0
|
||||
if len_a + len_b <= max_num_tokens:
|
||||
return False
|
||||
@@ -155,8 +158,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
|
||||
return tokens, tokentypes
|
||||
|
||||
|
||||
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
|
||||
["index", "label"])
|
||||
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])
|
||||
|
||||
|
||||
def is_start_piece(piece):
|
||||
@@ -168,16 +170,21 @@ def is_start_piece(piece):
|
||||
return not piece.startswith("##")
|
||||
|
||||
|
||||
def create_masked_lm_predictions(tokens,
|
||||
vocab_id_list, vocab_id_to_token_dict,
|
||||
masked_lm_prob,
|
||||
cls_id, sep_id, mask_id,
|
||||
max_predictions_per_seq,
|
||||
np_rng,
|
||||
max_ngrams=3,
|
||||
do_whole_word_mask=True,
|
||||
favor_longer_ngram=False,
|
||||
do_permutation=False):
|
||||
def create_masked_lm_predictions(
|
||||
tokens,
|
||||
vocab_id_list,
|
||||
vocab_id_to_token_dict,
|
||||
masked_lm_prob,
|
||||
cls_id,
|
||||
sep_id,
|
||||
mask_id,
|
||||
max_predictions_per_seq,
|
||||
np_rng,
|
||||
max_ngrams=3,
|
||||
do_whole_word_mask=True,
|
||||
favor_longer_ngram=False,
|
||||
do_permutation=False,
|
||||
):
|
||||
"""Creates the predictions for the masked LM objective.
|
||||
Note: Tokens here are vocab ids and not text tokens."""
|
||||
|
||||
@@ -187,7 +194,7 @@ def create_masked_lm_predictions(tokens,
|
||||
# on-the-fly whole word masking is possible.
|
||||
token_boundary = [0] * len(tokens)
|
||||
|
||||
for (i, token) in enumerate(tokens):
|
||||
for i, token in enumerate(tokens):
|
||||
if token == cls_id or token == sep_id:
|
||||
token_boundary[i] = 1
|
||||
continue
|
||||
@@ -197,8 +204,7 @@ def create_masked_lm_predictions(tokens,
|
||||
# Note that Whole Word Masking does *not* change the training code
|
||||
# at all -- we still predict each WordPiece independently, softmaxed
|
||||
# over the entire vocabulary.
|
||||
if (do_whole_word_mask and len(cand_indexes) >= 1 and
|
||||
not is_start_piece(vocab_id_to_token_dict[token])):
|
||||
if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]):
|
||||
cand_indexes[-1].append(i)
|
||||
else:
|
||||
cand_indexes.append([i])
|
||||
@@ -211,16 +217,14 @@ def create_masked_lm_predictions(tokens,
|
||||
masked_lm_labels = []
|
||||
|
||||
if masked_lm_prob == 0:
|
||||
return (output_tokens, masked_lm_positions,
|
||||
masked_lm_labels, token_boundary)
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
|
||||
|
||||
num_to_predict = min(max_predictions_per_seq,
|
||||
max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||
num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
|
||||
|
||||
# Note(mingdachen):
|
||||
# By default, we set the probabilities to favor shorter ngram sequences.
|
||||
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
|
||||
pvals = 1. / np.arange(1, max_ngrams + 1)
|
||||
pvals = 1.0 / np.arange(1, max_ngrams + 1)
|
||||
pvals /= pvals.sum(keepdims=True)
|
||||
|
||||
if favor_longer_ngram:
|
||||
@@ -230,7 +234,7 @@ def create_masked_lm_predictions(tokens,
|
||||
for idx in range(len(cand_indexes)):
|
||||
ngram_index = []
|
||||
for n in ngrams:
|
||||
ngram_index.append(cand_indexes[idx:idx + n])
|
||||
ngram_index.append(cand_indexes[idx : idx + n])
|
||||
ngram_indexes.append(ngram_index)
|
||||
|
||||
np_rng.shuffle(ngram_indexes)
|
||||
@@ -249,9 +253,10 @@ def create_masked_lm_predictions(tokens,
|
||||
if index in covered_indexes:
|
||||
continue
|
||||
|
||||
n = np_rng.choice(ngrams[:len(cand_index_set)],
|
||||
p=pvals[:len(cand_index_set)] /
|
||||
pvals[:len(cand_index_set)].sum(keepdims=True))
|
||||
n = np_rng.choice(
|
||||
ngrams[: len(cand_index_set)],
|
||||
p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
|
||||
)
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
# Note(mingdachen):
|
||||
@@ -309,9 +314,10 @@ def create_masked_lm_predictions(tokens,
|
||||
if index in covered_indexes or index in select_indexes:
|
||||
continue
|
||||
|
||||
n = np.random.choice(ngrams[:len(cand_index_set)],
|
||||
p=pvals[:len(cand_index_set)] /
|
||||
pvals[:len(cand_index_set)].sum(keepdims=True))
|
||||
n = np.random.choice(
|
||||
ngrams[: len(cand_index_set)],
|
||||
p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
|
||||
)
|
||||
index_set = sum(cand_index_set[n - 1], [])
|
||||
n -= 1
|
||||
|
||||
@@ -353,8 +359,7 @@ def create_masked_lm_predictions(tokens,
|
||||
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
|
||||
|
||||
|
||||
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
||||
masked_labels, pad_id, max_seq_length):
|
||||
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length):
|
||||
"""Pad sequences and convert them to numpy."""
|
||||
|
||||
# Some checks.
|
||||
@@ -370,8 +375,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
||||
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
|
||||
|
||||
# Padding mask.
|
||||
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
|
||||
dtype=np.int64)
|
||||
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)
|
||||
|
||||
# Lables and loss mask.
|
||||
labels = [-1] * max_seq_length
|
||||
@@ -386,26 +390,36 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
|
||||
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
|
||||
|
||||
|
||||
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed, skip_warmup,
|
||||
binary_head,
|
||||
dataset_type='standard_bert'):
|
||||
|
||||
def build_train_valid_test_datasets(
|
||||
data_prefix,
|
||||
data_impl,
|
||||
splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length,
|
||||
masked_lm_prob,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type="standard_bert",
|
||||
):
|
||||
if len(data_prefix) == 1:
|
||||
return _build_train_valid_test_datasets(data_prefix[0],
|
||||
data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type=dataset_type)
|
||||
return _build_train_valid_test_datasets(
|
||||
data_prefix[0],
|
||||
data_impl,
|
||||
splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length,
|
||||
masked_lm_prob,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type=dataset_type,
|
||||
)
|
||||
# Blending dataset.
|
||||
# Parse the values.
|
||||
output = get_datasets_weights_and_num_samples(data_prefix,
|
||||
train_valid_test_num_samples)
|
||||
output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
|
||||
prefixes, weights, datasets_train_valid_test_num_samples = output
|
||||
|
||||
# Build individual datasets.
|
||||
@@ -414,10 +428,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
test_datasets = []
|
||||
for i in range(len(prefixes)):
|
||||
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
|
||||
prefixes[i], data_impl, splits_string,
|
||||
prefixes[i],
|
||||
data_impl,
|
||||
splits_string,
|
||||
datasets_train_valid_test_num_samples[i],
|
||||
max_seq_length, masked_lm_prob, short_seq_prob,
|
||||
seed, skip_warmup, binary_head, dataset_type=dataset_type)
|
||||
max_seq_length,
|
||||
masked_lm_prob,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type=dataset_type,
|
||||
)
|
||||
if train_ds:
|
||||
train_datasets.append(train_ds)
|
||||
if valid_ds:
|
||||
@@ -436,31 +458,33 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
if test_datasets:
|
||||
blending_test_dataset = BlendableDataset(test_datasets, weights)
|
||||
|
||||
return (blending_train_dataset, blending_valid_dataset,
|
||||
blending_test_dataset)
|
||||
return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
|
||||
|
||||
|
||||
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length, masked_lm_prob,
|
||||
short_seq_prob, seed, skip_warmup,
|
||||
binary_head,
|
||||
dataset_type='standard_bert'):
|
||||
def _build_train_valid_test_datasets(
|
||||
data_prefix,
|
||||
data_impl,
|
||||
splits_string,
|
||||
train_valid_test_num_samples,
|
||||
max_seq_length,
|
||||
masked_lm_prob,
|
||||
short_seq_prob,
|
||||
seed,
|
||||
skip_warmup,
|
||||
binary_head,
|
||||
dataset_type="standard_bert",
|
||||
):
|
||||
logger = get_dist_logger()
|
||||
|
||||
if dataset_type not in DSET_TYPES:
|
||||
raise ValueError("Invalid dataset_type: ", dataset_type)
|
||||
|
||||
# Indexed dataset.
|
||||
indexed_dataset = get_indexed_dataset_(data_prefix,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
|
||||
|
||||
if dataset_type == DSET_TYPE_ICT:
|
||||
args = get_args()
|
||||
title_dataset = get_indexed_dataset_(args.titles_data_path,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup)
|
||||
|
||||
# Get start and end indices of train/valid/train into doc-idx
|
||||
# Note that doc-idx is designed to be num-docs + 1 so we can
|
||||
@@ -469,27 +493,29 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
|
||||
|
||||
# Print stats about the splits.
|
||||
logger.info('\n > dataset split:')
|
||||
logger.info("\n > dataset split:")
|
||||
|
||||
def print_split_stats(name, index):
|
||||
start_index = indexed_dataset.doc_idx[splits[index]]
|
||||
end_index = indexed_dataset.doc_idx[splits[index + 1]]
|
||||
logger.info('\n {}:'.format(name) +
|
||||
'\n document indices in [{}, {}) total of {} documents'.format(
|
||||
splits[index],
|
||||
splits[index + 1],
|
||||
splits[index + 1] - splits[index]) +
|
||||
'\n sentence indices in [{}, {}) total of {} sentences'.format(
|
||||
start_index,
|
||||
end_index,
|
||||
end_index - start_index),
|
||||
ranks=[0])
|
||||
print_split_stats('train', 0)
|
||||
print_split_stats('validation', 1)
|
||||
print_split_stats('test', 2)
|
||||
logger.info(
|
||||
"\n {}:".format(name)
|
||||
+ "\n document indices in [{}, {}) total of {} documents".format(
|
||||
splits[index], splits[index + 1], splits[index + 1] - splits[index]
|
||||
)
|
||||
+ "\n sentence indices in [{}, {}) total of {} sentences".format(
|
||||
start_index, end_index, end_index - start_index
|
||||
),
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
print_split_stats("train", 0)
|
||||
print_split_stats("validation", 1)
|
||||
print_split_stats("test", 2)
|
||||
|
||||
def build_dataset(index, name):
|
||||
from .bert_dataset import BertDataset
|
||||
|
||||
dataset = None
|
||||
if splits[index + 1] > splits[index]:
|
||||
# Get the pointer to the original doc-idx so we can set it later.
|
||||
@@ -508,7 +534,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
max_num_samples=train_valid_test_num_samples[index],
|
||||
max_seq_length=max_seq_length,
|
||||
seed=seed,
|
||||
binary_head=binary_head
|
||||
binary_head=binary_head,
|
||||
)
|
||||
|
||||
if dataset_type == DSET_TYPE_ICT:
|
||||
@@ -518,27 +544,26 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
title_dataset=title_dataset,
|
||||
query_in_block_prob=args.query_in_block_prob,
|
||||
use_one_sent_docs=args.use_one_sent_docs,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
dataset = BertDataset(
|
||||
indexed_dataset=indexed_dataset,
|
||||
masked_lm_prob=masked_lm_prob,
|
||||
short_seq_prob=short_seq_prob,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Set the original pointer so dataset remains the main dataset.
|
||||
indexed_dataset.set_doc_idx(doc_idx_ptr)
|
||||
# Checks.
|
||||
assert indexed_dataset.doc_idx[0] == 0
|
||||
assert indexed_dataset.doc_idx.shape[0] == \
|
||||
(total_num_of_documents + 1)
|
||||
assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
|
||||
return dataset
|
||||
|
||||
train_dataset = build_dataset(0, 'train')
|
||||
valid_dataset = build_dataset(1, 'valid')
|
||||
test_dataset = build_dataset(2, 'test')
|
||||
train_dataset = build_dataset(0, "train")
|
||||
valid_dataset = build_dataset(1, "valid")
|
||||
test_dataset = build_dataset(2, "test")
|
||||
|
||||
return (train_dataset, valid_dataset, test_dataset)
|
||||
|
||||
@@ -546,44 +571,41 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
|
||||
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
|
||||
logger = get_dist_logger()
|
||||
start_time = time.time()
|
||||
indexed_dataset = make_indexed_dataset(data_prefix,
|
||||
data_impl,
|
||||
skip_warmup)
|
||||
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
|
||||
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
|
||||
logger.info('\n > building dataset index ...', ranks=[0])
|
||||
logger.info('\n > finished creating indexed dataset in {:4f} '
|
||||
'seconds'.format(time.time() - start_time), ranks=[0])
|
||||
logger.info('\n > indexed dataset stats:' +
|
||||
'\n number of documents: {}'.format(
|
||||
indexed_dataset.doc_idx.shape[0] - 1) +
|
||||
'\n number of sentences: {}'.format(
|
||||
indexed_dataset.sizes.shape[0]),
|
||||
ranks=[0]
|
||||
)
|
||||
logger.info("\n > building dataset index ...", ranks=[0])
|
||||
logger.info(
|
||||
"\n > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time), ranks=[0]
|
||||
)
|
||||
logger.info(
|
||||
"\n > indexed dataset stats:"
|
||||
+ "\n number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1)
|
||||
+ "\n number of sentences: {}".format(indexed_dataset.sizes.shape[0]),
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
return indexed_dataset
|
||||
|
||||
|
||||
def get_train_valid_test_split_(splits_string, size):
|
||||
""" Get dataset splits from comma or '/' separated string list."""
|
||||
"""Get dataset splits from comma or '/' separated string list."""
|
||||
|
||||
splits = []
|
||||
if splits_string.find(',') != -1:
|
||||
splits = [float(s) for s in splits_string.split(',')]
|
||||
elif splits_string.find('/') != -1:
|
||||
splits = [float(s) for s in splits_string.split('/')]
|
||||
if splits_string.find(",") != -1:
|
||||
splits = [float(s) for s in splits_string.split(",")]
|
||||
elif splits_string.find("/") != -1:
|
||||
splits = [float(s) for s in splits_string.split("/")]
|
||||
else:
|
||||
splits = [float(splits_string)]
|
||||
while len(splits) < 3:
|
||||
splits.append(0.)
|
||||
splits.append(0.0)
|
||||
splits = splits[:3]
|
||||
splits_sum = sum(splits)
|
||||
assert splits_sum > 0.0
|
||||
splits = [split / splits_sum for split in splits]
|
||||
splits_index = [0]
|
||||
for index, split in enumerate(splits):
|
||||
splits_index.append(splits_index[index] +
|
||||
int(round(split * float(size))))
|
||||
splits_index.append(splits_index[index] + int(round(split * float(size))))
|
||||
diff = splits_index[-1] - size
|
||||
for index in range(1, len(splits_index)):
|
||||
splits_index[index] -= diff
|
||||
|
File diff suppressed because it is too large
@@ -2,12 +2,11 @@ import itertools
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from megatron import get_tokenizer
|
||||
from megatron import get_args
|
||||
from megatron import get_args, get_tokenizer
|
||||
from megatron.data.dataset_utils import get_indexed_dataset_
|
||||
from megatron.data.realm_dataset_utils import get_block_samples_mapping
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def make_attention_mask(source_block, target_block):
|
||||
"""
|
||||
@@ -20,16 +19,17 @@ def make_attention_mask(source_block, target_block):
|
||||
# (source_length, target_length)
|
||||
return mask
|
||||
|
||||
|
||||
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
|
||||
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
|
||||
rather than for training, since it is only built with a single epoch sample mapping.
|
||||
"""
|
||||
args = get_args()
|
||||
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
|
||||
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
|
||||
block_dataset = get_indexed_dataset_(args.data_path, "mmap", True)
|
||||
titles_dataset = get_indexed_dataset_(args.titles_data_path, "mmap", True)
|
||||
|
||||
kwargs = dict(
|
||||
name='full',
|
||||
name="full",
|
||||
block_dataset=block_dataset,
|
||||
title_dataset=titles_dataset,
|
||||
data_prefix=args.data_path,
|
||||
@@ -39,7 +39,7 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
|
||||
seed=1,
|
||||
query_in_block_prob=query_in_block_prob,
|
||||
use_titles=use_titles,
|
||||
use_one_sent_docs=args.use_one_sent_docs
|
||||
use_one_sent_docs=args.use_one_sent_docs,
|
||||
)
|
||||
dataset = ICTDataset(**kwargs)
|
||||
return dataset
|
||||
@@ -47,9 +47,22 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
|
||||
|
||||
class ICTDataset(Dataset):
|
||||
"""Dataset containing sentences and their blocks for an inverse cloze task."""
|
||||
def __init__(self, name, block_dataset, title_dataset, data_prefix,
|
||||
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
|
||||
seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
block_dataset,
|
||||
title_dataset,
|
||||
data_prefix,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
max_seq_length,
|
||||
query_in_block_prob,
|
||||
seed,
|
||||
use_titles=True,
|
||||
use_one_sent_docs=False,
|
||||
binary_head=False,
|
||||
):
|
||||
self.name = name
|
||||
self.seed = seed
|
||||
self.max_seq_length = max_seq_length
|
||||
@@ -61,8 +74,16 @@ class ICTDataset(Dataset):
|
||||
self.use_one_sent_docs = use_one_sent_docs
|
||||
|
||||
self.samples_mapping = get_block_samples_mapping(
|
||||
block_dataset, title_dataset, data_prefix, num_epochs,
|
||||
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
|
||||
block_dataset,
|
||||
title_dataset,
|
||||
data_prefix,
|
||||
num_epochs,
|
||||
max_num_samples,
|
||||
max_seq_length,
|
||||
seed,
|
||||
name,
|
||||
use_one_sent_docs,
|
||||
)
|
||||
self.tokenizer = get_tokenizer()
|
||||
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
|
||||
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
|
||||
@@ -99,8 +120,8 @@ class ICTDataset(Dataset):
|
||||
|
||||
# still need to truncate because blocks are concluded when
|
||||
# the sentence lengths have exceeded max_seq_length.
|
||||
query = query[:self.max_seq_length - 2]
|
||||
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
|
||||
query = query[: self.max_seq_length - 2]
|
||||
block = list(itertools.chain(*block))[: self.max_seq_length - title_pad_offset]
|
||||
|
||||
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
|
||||
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
|
||||
@@ -111,13 +132,13 @@ class ICTDataset(Dataset):
|
||||
block_data = sample_data.as_array()
|
||||
|
||||
sample = {
|
||||
'query_tokens': query_tokens,
|
||||
'query_mask': query_mask,
|
||||
'query_pad_mask': query_pad_mask,
|
||||
'context_tokens': context_tokens,
|
||||
'context_mask': context_mask,
|
||||
'context_pad_mask': context_pad_mask,
|
||||
'block_data': block_data,
|
||||
"query_tokens": query_tokens,
|
||||
"query_mask": query_mask,
|
||||
"query_pad_mask": query_pad_mask,
|
||||
"context_tokens": context_tokens,
|
||||
"context_mask": context_mask,
|
||||
"context_pad_mask": context_pad_mask,
|
||||
"block_data": block_data,
|
||||
}
|
||||
|
||||
return sample
|
||||
@@ -127,7 +148,7 @@ class ICTDataset(Dataset):
|
||||
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
|
||||
title = self.title_dataset[int(doc_idx)]
|
||||
|
||||
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
|
||||
block = list(itertools.chain(*block))[: self.max_seq_length - (3 + len(title))]
|
||||
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
|
||||
|
||||
return block_tokens, block_pad_mask
|
||||
|
@@ -27,17 +27,17 @@ def __best_fitting_dtype(vocab_size=None):
|
||||
|
||||
|
||||
def get_available_dataset_impl():
|
||||
return ['lazy', 'cached', 'mmap']
|
||||
return ["lazy", "cached", "mmap"]
|
||||
|
||||
|
||||
def infer_dataset_impl(path):
|
||||
if IndexedDataset.exists(path):
|
||||
with open(index_file_path(path), 'rb') as f:
|
||||
with open(index_file_path(path), "rb") as f:
|
||||
magic = f.read(8)
|
||||
if magic == IndexedDataset._HDR_MAGIC:
|
||||
return 'cached'
|
||||
return "cached"
|
||||
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
|
||||
return 'mmap'
|
||||
return "mmap"
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
@@ -47,7 +47,7 @@ def infer_dataset_impl(path):
|
||||
|
||||
|
||||
def make_builder(out_file, impl, vocab_size=None):
|
||||
if impl == 'mmap':
|
||||
if impl == "mmap":
|
||||
return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
|
||||
else:
|
||||
return IndexedDatasetBuilder(out_file)
|
||||
@@ -58,20 +58,20 @@ def make_dataset(path, impl, skip_warmup=False):
|
||||
print(f"Dataset does not exist: {path}")
|
||||
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
|
||||
return None
|
||||
if impl == 'infer':
|
||||
if impl == "infer":
|
||||
impl = infer_dataset_impl(path)
|
||||
if impl == 'lazy' and IndexedDataset.exists(path):
|
||||
if impl == "lazy" and IndexedDataset.exists(path):
|
||||
return IndexedDataset(path)
|
||||
elif impl == 'cached' and IndexedDataset.exists(path):
|
||||
elif impl == "cached" and IndexedDataset.exists(path):
|
||||
return IndexedCachedDataset(path)
|
||||
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
|
||||
elif impl == "mmap" and MMapIndexedDataset.exists(path):
|
||||
return MMapIndexedDataset(path, skip_warmup)
|
||||
print(f"Unknown dataset implementation: {impl}")
|
||||
return None
|
||||
|
||||
|
||||
def dataset_exists(path, impl):
|
||||
if impl == 'mmap':
|
||||
if impl == "mmap":
|
||||
return MMapIndexedDataset.exists(path)
|
||||
else:
|
||||
return IndexedDataset.exists(path)
|
||||
@@ -98,11 +98,11 @@ def code(dtype):
|
||||
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + '.idx'
|
||||
return prefix_path + ".idx"
|
||||
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + '.bin'
|
||||
return prefix_path + ".bin"
|
||||
|
||||
|
||||
def create_doc_idx(sizes):
|
||||
@@ -115,7 +115,8 @@ def create_doc_idx(sizes):
|
||||
|
||||
class IndexedDataset(torch.utils.data.Dataset):
|
||||
"""Loader for IndexedDataset"""
|
||||
_HDR_MAGIC = b'TNTIDX\x00\x00'
|
||||
|
||||
_HDR_MAGIC = b"TNTIDX\x00\x00"
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__()
|
||||
@@ -124,27 +125,28 @@ class IndexedDataset(torch.utils.data.Dataset):
|
||||
self.read_index(path)
|
||||
|
||||
def read_index(self, path):
|
||||
with open(index_file_path(path), 'rb') as f:
|
||||
with open(index_file_path(path), "rb") as f:
|
||||
magic = f.read(8)
|
||||
assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
|
||||
'Make sure that --dataset-impl is configured properly.')
|
||||
assert magic == self._HDR_MAGIC, (
|
||||
"Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
version = f.read(8)
|
||||
assert struct.unpack('<Q', version) == (1,)
|
||||
code, self.element_size = struct.unpack('<QQ', f.read(16))
|
||||
assert struct.unpack("<Q", version) == (1,)
|
||||
code, self.element_size = struct.unpack("<QQ", f.read(16))
|
||||
self.dtype = dtypes[code]
|
||||
self._len, self.s = struct.unpack('<QQ', f.read(16))
|
||||
self.doc_count = struct.unpack('<Q', f.read(8))
|
||||
self._len, self.s = struct.unpack("<QQ", f.read(16))
|
||||
self.doc_count = struct.unpack("<Q", f.read(8))
|
||||
self.dim_offsets = read_longs(f, self._len + 1)
|
||||
self.data_offsets = read_longs(f, self._len + 1)
|
||||
self.sizes = read_longs(f, self.s)
|
||||
self.doc_idx = read_longs(f, self.doc_count)
|
||||
|
||||
def read_data(self, path):
|
||||
self.data_file = open(data_file_path(path), 'rb', buffering=0)
|
||||
self.data_file = open(data_file_path(path), "rb", buffering=0)
|
||||
|
||||
def check_index(self, i):
|
||||
if i < 0 or i >= self._len:
|
||||
raise IndexError('index out of range')
|
||||
raise IndexError("index out of range")
|
||||
|
||||
def __del__(self):
|
||||
if self.data_file:
|
||||
@@ -157,7 +159,7 @@ class IndexedDataset(torch.utils.data.Dataset):
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
|
||||
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
@@ -166,7 +168,7 @@ class IndexedDataset(torch.utils.data.Dataset):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
|
||||
sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
|
||||
size = sum(sizes)
|
||||
a = np.empty(size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[start] * self.element_size)
|
||||
@@ -186,15 +188,14 @@ class IndexedDataset(torch.utils.data.Dataset):
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False # avoid prefetching to save memory
|
||||
return False # avoid prefetching to save memory
|
||||
|
||||
|
||||
class IndexedCachedDataset(IndexedDataset):
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__(path)
|
||||
self.cache = None
|
||||
@@ -219,7 +220,7 @@ class IndexedCachedDataset(IndexedDataset):
|
||||
for i in indices:
|
||||
self.cache_index[i] = ptx
|
||||
size = self.data_offsets[i + 1] - self.data_offsets[i]
|
||||
a = self.cache[ptx:ptx + size]
|
||||
a = self.cache[ptx : ptx + size]
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
ptx += size
|
||||
@@ -233,10 +234,10 @@ class IndexedCachedDataset(IndexedDataset):
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
|
||||
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
ptx = self.cache_index[i]
|
||||
np.copyto(a, self.cache[ptx:ptx + a.size])
|
||||
np.copyto(a, self.cache[ptx : ptx + a.size])
|
||||
return a
|
||||
elif isinstance(idx, slice):
|
||||
# Hack just to make this work, can optimizer later if necessary
|
||||
@@ -250,7 +251,7 @@ class IndexedDatasetBuilder(object):
|
||||
element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8}
|
||||
|
||||
def __init__(self, out_file, dtype=np.int32):
|
||||
self.out_file = open(out_file, 'wb')
|
||||
self.out_file = open(out_file, "wb")
|
||||
self.dtype = dtype
|
||||
self.data_offsets = [0]
|
||||
self.dim_offsets = [0]
|
||||
@@ -280,7 +281,7 @@ class IndexedDatasetBuilder(object):
|
||||
for dim_offset in index.dim_offsets[1:]:
|
||||
self.dim_offsets.append(begin + dim_offset)
|
||||
|
||||
with open(data_file_path(another_file), 'rb') as f:
|
||||
with open(data_file_path(another_file), "rb") as f:
|
||||
while True:
|
||||
data = f.read(1024)
|
||||
if data:
|
||||
@@ -290,12 +291,12 @@ class IndexedDatasetBuilder(object):

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, 'wb')
        index.write(b'TNTIDX\x00\x00')
        index.write(struct.pack('<Q', 1))
        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack('<Q', len(self.doc_idx)))
        index = open(index_file, "wb")
        index.write(b"TNTIDX\x00\x00")
        index.write(struct.pack("<Q", 1))
        index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
        index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack("<Q", len(self.doc_idx)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
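As a cross-check of the header layout written above (magic, version, dtype code plus element size, item and size counts, document count), a minimal sketch of reading it back; the file name and variable names are made up for illustration:

import struct

with open("sample.idx", "rb") as f:                     # hypothetical path
    assert f.read(8) == b"TNTIDX\x00\x00"               # magic
    (version,) = struct.unpack("<Q", f.read(8))
    dtype_code, element_size = struct.unpack("<QQ", f.read(16))
    num_items, num_sizes = struct.unpack("<QQ", f.read(16))
    (num_docs,) = struct.unpack("<Q", f.read(8))
    # dim_offsets, data_offsets, sizes (and doc_idx) follow as int64 arrays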
@@ -304,27 +305,24 @@ class IndexedDatasetBuilder(object):


def _warmup_mmap_file(path):
    with open(path, 'rb') as stream:
    with open(path, "rb") as stream:
        while stream.read(100 * 1024 * 1024):
            pass


class MMapIndexedDataset(torch.utils.data.Dataset):

    class Index(object):
        _HDR_MAGIC = b'MMIDIDX\x00\x00'
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):

            class _Writer(object):

                def __enter__(self):
                    self._file = open(path, 'wb')
                    self._file = open(path, "wb")

                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack('<Q', 1))
                    self._file.write(struct.pack('<B', code(dtype)))
                    self._file.write(struct.pack("<Q", 1))
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

@@ -343,19 +341,19 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack('<Q', len(sizes)))
                    self._file.write(struct.pack('<Q', len(doc_idx)))
                    self._file.write(struct.pack("<Q", len(sizes)))
                    self._file.write(struct.pack("<Q", len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order='C'))
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order='C'))
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order='C'))
                    self._file.write(doc_idx.tobytes(order="C"))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()
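_get_pointers() is defined outside the lines shown here; write() above only assumes that it turns the per-sample lengths into byte offsets within the .bin file. A rough, hypothetical equivalent of that arithmetic for illustration:

import numpy as np

def get_pointers_sketch(sizes, dtype=np.int64):
    # Byte offset of each sample = running sum of (length * itemsize)
    # over all preceding samples.
    itemsize = dtype().itemsize
    pointers, address = [], 0
    for size in sizes:
        pointers.append(address)
        address += size * itemsize
    return pointers

get_pointers_sketch([3, 2, 4])   # -> [0, 24, 40] with 8-byte tokens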
@@ -363,39 +361,41 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
                return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, 'rb') as stream:
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
                                                       'Make sure that --dataset-impl is configured properly.')
                version = struct.unpack('<Q', stream.read(8))
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly."
                )
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                dtype_code, = struct.unpack('<B', stream.read(1))
                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print(" warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print(" reading sizes...")
            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
            print(" reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer,
                                           dtype=np.int64,
                                           count=self._len,
                                           offset=offset + self._sizes.nbytes)
            self._pointers = np.frombuffer(
                self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes
            )
            print(" reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer,
                                          dtype=np.int64,
                                          count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
@@ -443,7 +443,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            print(" warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print(" creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode="r", order="C")
        print(" creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

@@ -474,7 +474,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            return sents

    def get(self, idx, offset=0, length=None):
        """ Retrieves a single item from the dataset with the option to only
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
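Illustrative use of get() as described by this docstring, via the make_dataset() helper that the test script below also uses; the dataset prefix is hypothetical:

ds = indexed_dataset.make_dataset("my_corpus_text_sentence", "mmap")  # hypothetical prefix
full = ds.get(0)                      # whole sample, same as ds[0]
part = ds.get(0, offset=2, length=5)  # 5 tokens starting at position 2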
@@ -506,20 +506,19 @@ class MMapIndexedDataset(torch.utils.data.Dataset):

    @staticmethod
    def exists(path):
        return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
        return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))


class MMapIndexedDatasetBuilder(object):

    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, 'wb')
        self._data_file = open(out_file, "wb")
        self._dtype = dtype
        self._sizes = []
        self._doc_idx = [0]

    def add_item(self, tensor):
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order='C'))
        self._data_file.write(np_array.tobytes(order="C"))
        self._sizes.append(np_array.size)

    def end_document(self):
@@ -534,7 +533,7 @@ class MMapIndexedDatasetBuilder(object):
            self._sizes.append(size)

        # Concatenate data
        with open(data_file_path(another_file), 'rb') as f:
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
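A minimal round-trip sketch with the builder above, assuming the usual prefix + .bin/.idx naming implied by data_file_path()/index_file_path(); the paths and token ids are made up:

import numpy as np
import torch

builder = MMapIndexedDatasetBuilder("sample.bin", dtype=np.int64)
builder.add_item(torch.tensor([101, 2023, 102]))   # one tokenized sentence
builder.end_document()
builder.finalize("sample.idx")

ds = MMapIndexedDataset("sample")                  # reads sample.idx / sample.bin
print(ds[0])                                       # -> array([ 101, 2023,  102])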
@@ -2,13 +2,12 @@
# put some code used during development and manual testing of
# indexed_dataset.

from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys

import torch
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer

script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))
@@ -42,7 +41,7 @@ def test_indexed_dataset(args):

def test_indexed_dataset_get(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = build_tokenizer(args)
    build_tokenizer(args)
    size = ds.sizes[0]
    print(f"size: {size}")
    full = ds.get(0)
@@ -61,6 +60,7 @@ def test_indexed_dataset_get(args):
    print(part)
    # print(tokenizer.detokenize(part.data.tolist()))


# def test_albert_dataset(args):
#     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
#     # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
@@ -81,34 +81,27 @@ def test_indexed_dataset_get(args):

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, help='prefix to data files')
    parser.add_argument('--dataset-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'])
    parser.add_argument('--count', type=int, default=10,
                        help='Number of samples/documents to print')
    parser.add_argument("--data", type=str, help="prefix to data files")
    parser.add_argument("--dataset-impl", type=str, default="infer", choices=["lazy", "cached", "mmap", "infer"])
    parser.add_argument("--count", type=int, default=10, help="Number of samples/documents to print")

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')
    group = parser.add_argument_group(title="tokenizer")
    group.add_argument(
        "--tokenizer-type",
        type=str,
        required=True,
        choices=["BertWordPieceLowerCase", "GPT2BPETokenizer"],
        help="What type of tokenizer to use.",
    )
    group.add_argument("--vocab-file", type=str, default=None, help="Path to the vocab file")
    group.add_argument("--merge-file", type=str, default=None, help="Path to the BPE merge file (if necessary).")

    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of epochs to plan for')
    parser.add_argument('--max-num-samples', type=int, default=None,
                        help='Maximum number of samples to plan for')
    parser.add_argument('--masked-lm-prob', type=float, default=0.15,
                        help='probability of masking tokens')
    parser.add_argument('--seq-length', type=int, default=512,
                        help='maximum sequence length')
    parser.add_argument('--short-seq-prob', type=float, default=0.1,
                        help='probability of creating a short sequence')
    parser.add_argument('--seed', type=int, default=1234,
                        help='random seed')
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to plan for")
    parser.add_argument("--max-num-samples", type=int, default=None, help="Maximum number of samples to plan for")
    parser.add_argument("--masked-lm-prob", type=float, default=0.15, help="probability of masking tokens")
    parser.add_argument("--seq-length", type=int, default=512, help="maximum sequence length")
    parser.add_argument("--short-seq-prob", type=float, default=0.1, help="probability of creating a short sequence")
    parser.add_argument("--seed", type=int, default=1234, help="random seed")
    args = parser.parse_args()
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
@@ -117,7 +110,7 @@ def main():
    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

    # test_albert_dataset(args)
    # test_albert_dataset(args)
    test_indexed_dataset_get(args)