Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-23 16:08:55 +00:00)

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -4,7 +4,7 @@ from colossalai.legacy.amp import AMP_TYPE

TRAIN_ITERS = 10
DECAY_ITERS = 4
WARMUP_FRACTION = 0.01
GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU
GLOBAL_BATCH_SIZE = 32  # dp world size * sentences per GPU
EVAL_ITERS = 10
EVAL_INTERVAL = 10
LR = 0.0001
@@ -28,8 +28,8 @@ SEED = 1234
NUM_MICRO_BATCHES = 4

# colossalai config
parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence'))
parallel = dict(pipeline=1, tensor=dict(size=2, mode="sequence"))

fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True)

gradient_handler = [dict(type='SequenceParallelGradientHandler')]
gradient_handler = [dict(type="SequenceParallelGradientHandler")]
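The file above is ColossalAI's legacy dict-style Python configuration. A rough usage sketch follows (hedged: the launcher entry point varies across ColossalAI versions, and "config.py" is a hypothetical filename for the file diffed above):

# sketch: consuming a legacy dict-style config file such as the one diffed above
import colossalai

# launch_from_torch accepted a config path in older ColossalAI releases;
# verify against the version this commit targets before relying on it
colossalai.launch_from_torch(config="config.py")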
@@ -15,16 +15,13 @@ def cyclic_iter(iter):
            yield x


def build_train_valid_test_data_iterators(train_iters,
                                          global_batch_size,
                                          eval_interval,
                                          eval_iters,
                                          dataloader_type='single',
                                          **kwargs):
def build_train_valid_test_data_iterators(
    train_iters, global_batch_size, eval_interval, eval_iters, dataloader_type="single", **kwargs
):
    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    logger = get_dist_logger()
    logger.info('> building train, validation, and test datasets ...', ranks=[0])
    logger.info("> building train, validation, and test datasets ...", ranks=[0])

    # Backward compatibility, assume fixed batch size.
    # if iteration > 0 and consumed_train_samples == 0:
@@ -38,29 +35,29 @@ def build_train_valid_test_data_iterators(train_iters,

    # Data loader only on rank 0 of each model parallel group.
    if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:

        # Number of train/valid/test samples.
        train_samples = train_iters * global_batch_size
        eval_iters_ = (train_iters // eval_interval + 1) * eval_iters
        test_iters = eval_iters
        train_val_test_num_samples = [train_samples, eval_iters_ * global_batch_size, test_iters * global_batch_size]
        logger.info(' > datasets target sizes (minimum size):')
        logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0])
        logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0])
        logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0])
        logger.info(" > datasets target sizes (minimum size):")
        logger.info(" train: {}".format(train_val_test_num_samples[0]), ranks=[0])
        logger.info(" validation: {}".format(train_val_test_num_samples[1]), ranks=[0])
        logger.info(" test: {}".format(train_val_test_num_samples[2]), ranks=[0])

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            train_valid_test_num_samples=train_val_test_num_samples, **kwargs)
            train_valid_test_num_samples=train_val_test_num_samples, **kwargs
        )

        # Build dataloaders.
        dp_size = gpc.get_world_size(ParallelMode.DATA)
        train_dataloader = build_pretraining_data_loader(train_ds,
                                                         consumed_samples=0,
                                                         micro_batch_size=global_batch_size // dp_size)
        valid_dataloader = build_pretraining_data_loader(valid_ds,
                                                         consumed_samples=0,
                                                         micro_batch_size=global_batch_size // dp_size)
        train_dataloader = build_pretraining_data_loader(
            train_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size
        )
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size
        )
        test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size // dp_size)

        # Flags to know if we need to do training/validation/testing.
@@ -73,29 +70,26 @@ def build_train_valid_test_data_iterators(train_iters,
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
                                group=gpc.get_group(ParallelMode.TENSOR))
    torch.distributed.broadcast(
        flags, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
    )

    # Build iterators.
    dl_type = dataloader_type
    assert dl_type in ['single', 'cyclic']
    assert dl_type in ["single", "cyclic"]

    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(train_dataloader))
        train_data_iterator = iter(train_dataloader) if dl_type == "single" else iter(cyclic_iter(train_dataloader))
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(valid_dataloader))
        valid_data_iterator = iter(valid_dataloader) if dl_type == "single" else iter(cyclic_iter(valid_dataloader))
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(test_dataloader))
        test_data_iterator = iter(test_dataloader) if dl_type == "single" else iter(cyclic_iter(test_dataloader))
    else:
        test_data_iterator = None

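For reference, `cyclic_iter` (whose body is truncated to the single `yield x` line in the hunk above) wraps a finite dataloader in an endless iterator. A minimal sketch consistent with the fragment shown (hedged: reconstructed from the visible line, not copied from this commit):

# sketch: restart the underlying iterable forever so "cyclic" iterators never raise StopIteration
def cyclic_iter(iter):
    while True:
        for x in iter:
            yield x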
@@ -15,7 +15,7 @@ def _build_key_size_numel_dictionaries(keys, data):
    if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        offset = 0
        for key in keys:
            assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
            assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
            size = data[key].size()
            for i, s in enumerate(size):
                sizes[i + offset] = s
@@ -23,9 +23,9 @@ def _build_key_size_numel_dictionaries(keys, data):

    # Move to GPU and broadcast.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(sizes_cuda,
                                gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
                                group=gpc.get_group(ParallelMode.TENSOR))
    torch.distributed.broadcast(
        sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
    )

    # Move back to cpu and unpack.
    sizes_cpu = sizes_cuda.cpu()
@@ -73,9 +73,9 @@ def broadcast_data(keys, data, datatype):
    flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)

    # Broadcast
    torch.distributed.broadcast(flatten_data,
                                gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
                                group=gpc.get_group(ParallelMode.TENSOR))
    torch.distributed.broadcast(
        flatten_data, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
    )

    # Unpack
    output = {}
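The two broadcasts above implement a pack-and-ship protocol: rank 0 of each tensor-parallel group first broadcasts the per-key tensor shapes, then one flattened buffer, which every receiver slices back into named tensors. A self-contained sketch of the same idea (hedged: plain torch.distributed with generic arguments, not the gpc-based helpers diffed above):

# sketch: broadcast a dict of tensors by packing them into one flat buffer
import torch
import torch.distributed as dist

def broadcast_tensor_dict(data, src, group=None):
    # assumes an initialized process group and that every rank already knows
    # the shapes; the code above broadcasts the sizes first precisely to
    # remove that assumption
    keys = sorted(data.keys())
    flat = torch.cat([data[k].reshape(-1) for k in keys]).cuda()
    dist.broadcast(flat, src, group=group)  # one collective instead of len(keys)
    out, offset = {}, 0
    for k in keys:
        numel = data[k].numel()
        out[k] = flat[offset:offset + numel].view_as(data[k])
        offset += numel
    return out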
@@ -93,7 +93,7 @@ def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
    datatype = torch.int64

    # Broadcast data.
@@ -104,12 +104,12 @@ def get_batch(data_iterator):
    data_b = broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()
    tokens = data_b["text"].long()
    types = data_b["types"].long()
    sentence_order = data_b["is_random"].long()
    loss_mask = data_b["loss_mask"].float()
    lm_labels = data_b["labels"].long()
    padding_mask = data_b["padding_mask"].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask

@@ -118,7 +118,7 @@ def get_batch_for_sequence_parallel(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
    keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
    datatype = torch.int64

    # Broadcast data.
@@ -134,24 +134,23 @@ def get_batch_for_sequence_parallel(data_iterator):
    global_rank = torch.distributed.get_rank()
    local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR)
    local_rank = global_rank % local_world_size
    seq_length = data_b['text'].size(1)
    seq_length = data_b["text"].size(1)
    sub_seq_length = seq_length // local_world_size
    sub_seq_start = local_rank * sub_seq_length
    sub_seq_end = (local_rank + 1) * sub_seq_length
    #
    # # Unpack.
    tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long()
    types = data_b['types'][:, sub_seq_start:sub_seq_end].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float()
    lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long()
    padding_mask = data_b['padding_mask'].long()
    tokens = data_b["text"][:, sub_seq_start:sub_seq_end].long()
    types = data_b["types"][:, sub_seq_start:sub_seq_end].long()
    sentence_order = data_b["is_random"].long()
    loss_mask = data_b["loss_mask"][:, sub_seq_start:sub_seq_end].float()
    lm_labels = data_b["labels"][:, sub_seq_start:sub_seq_end].long()
    padding_mask = data_b["padding_mask"].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask

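# Worked example of the sequence-parallel slicing arithmetic above (values
# chosen for illustration): with seq_length = 512 and local_world_size = 4,
# sub_seq_length = 128, so rank 0 holds tokens [0, 128), rank 1 [128, 256),
# rank 2 [256, 384), and rank 3 [384, 512). Per-sequence fields such as
# is_random and padding_mask are kept whole on every rank.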

class SequenceParallelDataIterator:

    def __init__(self, data_iter):
        self.data_iter = data_iter


@@ -41,10 +41,19 @@ except:


class BertDataset(Dataset):

    def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
                 short_seq_prob, seed, binary_head):

    def __init__(
        self,
        name,
        indexed_dataset,
        data_prefix,
        num_epochs,
        max_num_samples,
        masked_lm_prob,
        max_seq_length,
        short_seq_prob,
        seed,
        binary_head,
    ):
        # Params to store.
        self.name = name
        self.seed = seed
@@ -61,11 +70,12 @@ class BertDataset(Dataset):
            data_prefix,
            num_epochs,
            max_num_samples,
            self.max_seq_length - 3, # account for added tokens,
            self.max_seq_length - 3,  # account for added tokens,
            short_seq_prob,
            self.seed,
            self.name,
            self.binary_head)
            self.binary_head,
        )

        # Vocab stuff.
        tokenizer = get_tokenizer()
@@ -89,7 +99,7 @@ class BertDataset(Dataset):
        return build_training_sample(
            sample,
            seq_length,
            self.max_seq_length, # needed for padding
            self.max_seq_length,  # needed for padding
            self.vocab_id_list,
            self.vocab_id_to_token_dict,
            self.cls_id,
@@ -98,37 +108,39 @@ class BertDataset(Dataset):
            self.pad_id,
            self.masked_lm_prob,
            np_rng,
            self.binary_head)
            self.binary_head,
        )


def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob,
                         seed, name, binary_head):
def get_samples_mapping_(
    indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head
):
    logger = get_dist_logger()
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
            raise ValueError("Need to specify either max_num_samples " "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    indexmap_filename += "_{}_indexmap".format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
        indexmap_filename += "_{}ep".format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'
        indexmap_filename += "_{}mns".format(max_num_samples)
    indexmap_filename += "_{}msl".format(max_seq_length)
    indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob)
    indexmap_filename += "_{}s".format(seed)
    indexmap_filename += ".npy"

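    # Worked example of the filename scheme above (illustrative values): with
    # data_prefix="my-corpus", name="train", num_epochs left at its sentinel,
    # max_num_samples=1000, max_seq_length=509, short_seq_prob=0.1, seed=1234,
    # the cache file becomes "my-corpus_train_indexmap_1000mns_509msl_0.10ssp_1234s.npy".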
    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))
    if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
        print(
            " > WARNING: could not find index map file {}, building "
            "the indices on rank 0 ...".format(indexmap_filename)
        )

        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
@@ -137,18 +149,27 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl
        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0])
        logger.info("\n > building samples index mapping for {} ...".format(name), ranks=[0])
        # First compile and then import.
        samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs,
                                                max_num_samples, max_seq_length, short_seq_prob, seed, verbose,
                                                2 if binary_head else 1)
        logger.info('\n > done building samples index maping', ranks=[0])
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length,
            short_seq_prob,
            seed,
            verbose,
            2 if binary_head else 1,
        )
        logger.info("\n > done building samples index maping", ranks=[0])
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0])
        logger.info("\n > saved the index mapping in {}".format(indexmap_filename), ranks=[0])
        # Make sure all the ranks have built the mapping
        logger.info('\n > elapsed time to build and save samples mapping '
                    '(seconds): {:4f}'.format(time.time() - start_time),
                    ranks=[0])
        logger.info(
            "\n > elapsed time to build and save samples mapping " "(seconds): {:4f}".format(time.time() - start_time),
            ranks=[0],
        )
    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
@@ -156,22 +177,38 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl
    torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA))
    if gpc.is_initialized(ParallelMode.PIPELINE):
        torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE))
    assert counts[0].item() == (torch.distributed.get_world_size() //
                                torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE)))
    assert counts[0].item() == (
        torch.distributed.get_world_size()
        // torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))
    )

    # Load indexed dataset.
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) +
                '\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) +
                '\n total number of samples: {}'.format(samples_mapping.shape[0]),
                ranks=[0])
    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r")
    logger.info(
        "\n > loading indexed mapping from {}".format(indexmap_filename)
        + "\n loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
        + "\n total number of samples: {}".format(samples_mapping.shape[0]),
        ranks=[0],
    )

    return samples_mapping


def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id,
                          sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head):
def build_training_sample(
    sample,
    target_seq_length,
    max_seq_length,
    vocab_id_list,
    vocab_id_to_token_dict,
    cls_id,
    sep_id,
    mask_id,
    pad_id,
    masked_lm_prob,
    np_rng,
    binary_head,
):
    """Build training sample.

    Arguments:
@@ -215,22 +252,30 @@ def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_li

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels,
     _) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id,
                                       mask_id, max_predictions_per_seq, np_rng)
    (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
        tokens,
        vocab_id_list,
        vocab_id_to_token_dict,
        masked_lm_prob,
        cls_id,
        sep_id,
        mask_id,
        max_predictions_per_seq,
        np_rng,
    )

    # Padding.
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
        tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
    )

    train_sample = {
        'text': tokens_np,
        'types': tokentypes_np,
        'labels': labels_np,
        'is_random': int(is_next_random),
        'loss_mask': loss_mask_np,
        'padding_mask': padding_mask_np,
        'truncated': int(truncated)
        "text": tokens_np,
        "types": tokentypes_np,
        "labels": labels_np,
        "is_random": int(is_next_random),
        "loss_mask": loss_mask_np,
        "padding_mask": padding_mask_np,
        "truncated": int(truncated),
    }
    return train_sample

@@ -22,9 +22,7 @@ import torch


class BlendableDataset(torch.utils.data.Dataset):

    def __init__(self, datasets, weights):

        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)
@@ -46,12 +44,16 @@ class BlendableDataset(torch.utils.data.Dataset):
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        from . import helpers
        helpers.build_blending_indices(self.dataset_index,
                                       self.dataset_sample_index,
                                       weights, num_datasets, self.size,
                                       torch.distributed.get_rank() == 0)
        print('> elapsed time for building blendable dataset indices: '
              '{:.2f} (sec)'.format(time.time() - start_time))

        helpers.build_blending_indices(
            self.dataset_index,
            self.dataset_sample_index,
            weights,
            num_datasets,
            self.size,
            torch.distributed.get_rank() == 0,
        )
        print("> elapsed time for building blendable dataset indices: " "{:.2f} (sec)".format(time.time() - start_time))

    def __len__(self):
        return self.size

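BlendableDataset mixes several datasets according to `weights` by precomputing, for every global index, which dataset to draw from and which sample within it. A minimal pure-Python sketch of what the compiled `helpers.build_blending_indices` computes (hedged: a conceptually equivalent greedy assignment, not the actual C++ implementation):

# sketch: assign each global sample to the dataset currently furthest below its target weight
def build_blending_indices(weights, size):
    # weights are assumed normalized to sum to 1
    counts = [0] * len(weights)
    dataset_index, dataset_sample_index = [], []
    for i in range(size):
        # pick the dataset whose realized share lags its weight the most
        d = max(range(len(weights)), key=lambda j: weights[j] * (i + 1) - counts[j])
        dataset_index.append(d)
        dataset_sample_index.append(counts[d])
        counts[d] += 1
    return dataset_index, dataset_sample_index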
@@ -1,29 +1,34 @@
from .blendable_dataset import BlendableDataset
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
from .bert_dataset import BertDataset
from colossalai.logging import get_dist_logger

DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
from .bert_dataset import BertDataset
from .blendable_dataset import BlendableDataset
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_

DSET_TYPE_BERT = "standard_bert"
DSET_TYPE_ICT = "ict"
DSET_TYPE_T5 = "t5"

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     binary_head,
                                     dataset_type='standard_bert'):

def _build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: ", dataset_type)

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    # Get start and end indices of train/valid/train into doc-idx
    # Note that doc-idx is designed to be num-docs + 1 so we can
@@ -34,22 +39,25 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    logger = get_dist_logger()

    # Print stats about the splits.
    logger.info('\n > dataset split:', ranks=[0])
    logger.info("\n > dataset split:", ranks=[0])

    def print_split_stats(name, index):
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        logger.info('\n {}:'.format(name) +
                    '\n document indices in [{}, {}) total of {} documents'.format(
                        splits[index], splits[index + 1],
                        splits[index + 1] - splits[index]) +
                    '\n sentence indices in [{}, {}) total of {} sentences'.format(
                        start_index, end_index,
                        end_index - start_index),
                    ranks=[0])
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)
        logger.info(
            "\n {}:".format(name)
            + "\n document indices in [{}, {}) total of {} documents".format(
                splits[index], splits[index + 1], splits[index + 1] - splits[index]
            )
            + "\n sentence indices in [{}, {}) total of {} sentences".format(
                start_index, end_index, end_index - start_index
            ),
            ranks=[0],
        )

    print_split_stats("train", 0)
    print_split_stats("validation", 1)
    print_split_stats("test", 2)

    def build_dataset(index, name):
        dataset = None
@@ -80,44 +88,53 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                masked_lm_prob=masked_lm_prob,
                short_seq_prob=short_seq_prob,
                binary_head=binary_head,
                **kwargs
                **kwargs,
            )

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
            assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')
    train_dataset = build_dataset(0, "train")
    valid_dataset = build_dataset(1, "valid")
    test_dataset = build_dataset(2, "test")

    return (train_dataset, valid_dataset, test_dataset)


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup,
                                    binary_head,
                                    dataset_type='standard_bert'):

def build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                dataset_type=dataset_type)
        return _build_train_valid_test_datasets(
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            dataset_type=dataset_type,
        )
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
@@ -126,10 +143,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, dataset_type=dataset_type)
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            dataset_type=dataset_type,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
@@ -148,5 +173,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)

@@ -14,7 +14,6 @@
# limitations under the License.
"""Dataloaders."""

import random

import torch

@@ -22,61 +21,60 @@ from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc


def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0):
def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type="single", num_workers=0):
    """Build dataloader given an input dataset."""

    if dataset is None:
        return None

    # Megatron sampler
    if dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset),
                                                   consumed_samples=consumed_samples,
                                                   micro_batch_size=micro_batch_size,
                                                   data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
                                                   data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
    elif dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset),
                                                         consumed_samples=consumed_samples,
                                                         micro_batch_size=micro_batch_size,
                                                         data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
                                                         data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
    if dataloader_type == "single":
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
            data_parallel_size=gpc.get_world_size(ParallelMode.DATA),
        )
    elif dataloader_type == "cyclic":
        batch_sampler = MegatronPretrainingRandomSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
            data_parallel_size=gpc.get_world_size(ParallelMode.DATA),
        )
    else:
        raise Exception('{} dataloader type is not supported.'.format(dataloader_type))
        raise Exception("{} dataloader type is not supported.".format(dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True)


class MegatronPretrainingSampler:

    def __init__(self,
                 total_samples,
                 consumed_samples,
                 micro_batch_size,
                 data_parallel_rank,
                 data_parallel_size,
                 drop_last=True):
    def __init__(
        self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True
    ):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
        assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format(
            self.consumed_samples, self.total_samples
        )
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)
        assert (
            self.data_parallel_rank < data_parallel_size
        ), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
            self.data_parallel_rank, data_parallel_size
        )

    def __len__(self):
        return self.total_samples
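    # Worked example of the sharding arithmetic above (illustrative numbers):
    # with micro_batch_size=4 and data_parallel_size=8, each global step consumes
    # micro_batch_times_data_parallel_size = 32 samples, and data-parallel rank r
    # takes the contiguous slice [r * 4, (r + 1) * 4) of every 32-sample window.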
@@ -103,7 +101,6 @@ class MegatronPretrainingSampler:


class MegatronPretrainingRandomSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
@@ -111,19 +108,18 @@ class MegatronPretrainingRandomSampler:
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)
        assert (
            self.data_parallel_rank < data_parallel_size
        ), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
            self.data_parallel_rank, data_parallel_size
        )

    def __len__(self):
        return self.total_samples
@@ -135,8 +131,7 @@ class MegatronPretrainingRandomSampler:
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # data sharding and random sampling
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
            * self.micro_batch_size
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size


@@ -18,32 +18,33 @@
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.

import collections
import math
import time
import collections
from colossalai.logging import get_dist_logger

import numpy as np

from colossalai.logging import get_dist_logger

from .blendable_dataset import BlendableDataset
from .indexed_dataset import make_dataset as make_indexed_dataset

DSET_TYPE_STD = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_STD = "standard_bert"
DSET_TYPE_ICT = "ict"

DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]


def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):

def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0]*num_datasets
    prefixes = [0]*num_datasets
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2*i])
        prefixes[i] = (data_prefix[2*i+1]).strip()
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()
    # Normalize weights
    weight_sum = 0.0
    for weight in weights:
@@ -57,8 +58,8 @@ def get_datasets_weights_and_num_samples(data_prefix,
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005))
             for val in train_valid_test_num_samples])
            [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
        )

    return prefixes, weights, datasets_train_valid_test_num_samples

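# Worked example for the parsing and the 1.005 oversampling factor above
# (illustrative values): data_prefix = ["0.3", "wiki", "0.7", "books"] yields
# prefixes ["wiki", "books"] and weights [0.3, 0.7] (already summing to 1, so
# normalization is a no-op). With train_valid_test_num_samples = [1000, 100, 10],
# the "wiki" dataset is asked for ceil(1000 * 0.3 * 1.005) = 302 training
# samples; the 0.5% slack guards against rounding shortfalls when blending.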
@@ -68,11 +69,13 @@ def compile_helper():
    is invoked on a single process."""
    import os
    import subprocess

    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    ret = subprocess.run(["make", "-C", path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        import sys

        sys.exit(1)


@@ -82,7 +85,7 @@ def get_a_and_b_segments(sample, np_rng):
    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'
    assert n_sentences > 1, "make sure each sample has at least two sentences."

    # First part:
    # `a_end` is how many sentences go into the `A`.
@@ -110,7 +113,7 @@ def get_a_and_b_segments(sample, np_rng):

def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    #print(len_a, len_b, max_num_tokens)
    # print(len_a, len_b, max_num_tokens)
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
@@ -155,8 +158,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    return tokens, tokentypes


MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])


def is_start_piece(piece):
@@ -168,16 +170,21 @@ def is_start_piece(piece):
    return not piece.startswith("##")


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False):
def create_masked_lm_predictions(
    tokens,
    vocab_id_list,
    vocab_id_to_token_dict,
    masked_lm_prob,
    cls_id,
    sep_id,
    mask_id,
    max_predictions_per_seq,
    np_rng,
    max_ngrams=3,
    do_whole_word_mask=True,
    favor_longer_ngram=False,
    do_permutation=False,
):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

@@ -187,7 +194,7 @@ def create_masked_lm_predictions(tokens,
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
    for i, token in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
@@ -197,8 +204,7 @@ def create_masked_lm_predictions(tokens,
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if (do_whole_word_mask and len(cand_indexes) >= 1 and
                not is_start_piece(vocab_id_to_token_dict[token])):
        if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
@@ -211,16 +217,14 @@ def create_masked_lm_predictions(tokens,
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)
        return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))
    num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))

    # Note(mingdachen):
    # By default, we set the probabilities to favor shorter ngram sequences.
    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals = 1.0 / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)

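    # Worked example of the ngram length distribution above: with max_ngrams=3,
    # pvals starts as [1, 1/2, 1/3]; normalizing by the sum 11/6 gives
    # [6/11, 3/11, 2/11] ~= [0.545, 0.273, 0.182], i.e. unigram masks are
    # favored roughly 3:1 over trigram masks unless favor_longer_ngram is set.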
    if favor_longer_ngram:
@@ -230,7 +234,7 @@ def create_masked_lm_predictions(tokens,
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
            ngram_index.append(cand_indexes[idx : idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)
@@ -249,9 +253,10 @@ def create_masked_lm_predictions(tokens,
            if index in covered_indexes:
                continue

        n = np_rng.choice(ngrams[:len(cand_index_set)],
                          p=pvals[:len(cand_index_set)] /
                          pvals[:len(cand_index_set)].sum(keepdims=True))
        n = np_rng.choice(
            ngrams[: len(cand_index_set)],
            p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
        )
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
@@ -309,9 +314,10 @@ def create_masked_lm_predictions(tokens,
            if index in covered_indexes or index in select_indexes:
                continue

        n = np.random.choice(ngrams[:len(cand_index_set)],
                             p=pvals[:len(cand_index_set)] /
                             pvals[:len(cand_index_set)].sum(keepdims=True))
        n = np.random.choice(
            ngrams[: len(cand_index_set)],
            p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
        )
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1

@@ -353,8 +359,7 @@ def create_masked_lm_predictions(tokens,
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
@@ -370,8 +375,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)

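    # Worked example of the padding mask above: with num_tokens=5 and
    # max_seq_length=8, padding_length=3 and the mask is [1, 1, 1, 1, 1, 0, 0, 0],
    # so attention and loss computations can ignore the pad positions.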
    # Lables and loss mask.
    labels = [-1] * max_seq_length
@@ -386,26 +390,36 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup,
                                    binary_head,
                                    dataset_type='standard_bert'):

def build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                dataset_type=dataset_type)
        return _build_train_valid_test_datasets(
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            dataset_type=dataset_type,
        )
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
@@ -414,10 +428,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            prefixes[i],
            data_impl,
            splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, dataset_type=dataset_type)
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            dataset_type=dataset_type,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
@@ -436,31 +458,33 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     binary_head,
                                     dataset_type='standard_bert'):
def _build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
    logger = get_dist_logger()

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: ", dataset_type)

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    if dataset_type == DSET_TYPE_ICT:
        args = get_args()
        title_dataset = get_indexed_dataset_(args.titles_data_path,
                                             data_impl,
                                             skip_warmup)
        title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup)

    # Get start and end indices of train/valid/train into doc-idx
    # Note that doc-idx is designed to be num-docs + 1 so we can
@@ -469,27 +493,29 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    logger.info('\n > dataset split:')
    logger.info("\n > dataset split:")

    def print_split_stats(name, index):
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        logger.info('\n {}:'.format(name) +
                    '\n document indices in [{}, {}) total of {} documents'.format(
                        splits[index],
                        splits[index + 1],
                        splits[index + 1] - splits[index]) +
                    '\n sentence indices in [{}, {}) total of {} sentences'.format(
                        start_index,
                        end_index,
                        end_index - start_index),
                    ranks=[0])
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)
        logger.info(
            "\n {}:".format(name)
            + "\n document indices in [{}, {}) total of {} documents".format(
                splits[index], splits[index + 1], splits[index + 1] - splits[index]
            )
            + "\n sentence indices in [{}, {}) total of {} sentences".format(
                start_index, end_index, end_index - start_index
            ),
            ranks=[0],
        )

    print_split_stats("train", 0)
    print_split_stats("validation", 1)
    print_split_stats("test", 2)

    def build_dataset(index, name):
        from .bert_dataset import BertDataset

        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
@@ -508,7 +534,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
                binary_head=binary_head
                binary_head=binary_head,
            )

            if dataset_type == DSET_TYPE_ICT:
@@ -518,27 +544,26 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                    title_dataset=title_dataset,
                    query_in_block_prob=args.query_in_block_prob,
                    use_one_sent_docs=args.use_one_sent_docs,
                    **kwargs
                    **kwargs,
                )
            else:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    **kwargs
                    **kwargs,
                )

            # Set the original pointer so dataset remains the main dataset.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
            assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')
    train_dataset = build_dataset(0, "train")
    valid_dataset = build_dataset(1, "valid")
    test_dataset = build_dataset(2, "test")

    return (train_dataset, valid_dataset, test_dataset)

@@ -546,44 +571,41 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    logger = get_dist_logger()
    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    logger.info('\n > building dataset index ...', ranks=[0])
    logger.info('\n > finished creating indexed dataset in {:4f} '
                'seconds'.format(time.time() - start_time), ranks=[0])
    logger.info('\n > indexed dataset stats:' +
                '\n number of documents: {}'.format(
                    indexed_dataset.doc_idx.shape[0] - 1) +
                '\n number of sentences: {}'.format(
                    indexed_dataset.sizes.shape[0]),
                ranks=[0]
                )
    logger.info("\n > building dataset index ...", ranks=[0])
    logger.info(
        "\n > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time), ranks=[0]
    )
    logger.info(
        "\n > indexed dataset stats:"
        + "\n number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1)
        + "\n number of sentences: {}".format(indexed_dataset.sizes.shape[0]),
        ranks=[0],
    )

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list."""
    """Get dataset splits from comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    if splits_string.find(",") != -1:
        splits = [float(s) for s in splits_string.split(",")]
    elif splits_string.find("/") != -1:
        splits = [float(s) for s in splits_string.split("/")]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
        splits.append(0.0)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
        splits_index.append(splits_index[index] + int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff

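    # Worked example for get_train_valid_test_split_: splits_string="949,50,1"
    # with size=10000 normalizes to [0.949, 0.05, 0.001], giving cumulative
    # boundaries [0, 9490, 9990, 10000]; any rounding overshoot in the final
    # boundary is pushed back through the trailing indices by the diff loop above.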
File diff suppressed because it is too large
@@ -2,12 +2,11 @@ import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron import get_args
from megatron import get_args, get_tokenizer
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
from torch.utils.data import Dataset


def make_attention_mask(source_block, target_block):
    """
@@ -20,16 +19,17 @@ def make_attention_mask(source_block, target_block):
    # (source_length, target_length)
    return mask


def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
    rather than for training, since it is only built with a single epoch sample mapping.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
    block_dataset = get_indexed_dataset_(args.data_path, "mmap", True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, "mmap", True)

    kwargs = dict(
        name='full',
        name="full",
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
@@ -39,7 +39,7 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=args.use_one_sent_docs
        use_one_sent_docs=args.use_one_sent_docs,
    )
    dataset = ICTDataset(**kwargs)
    return dataset
@@ -47,9 +47,22 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):

class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):

    def __init__(
        self,
        name,
        block_dataset,
        title_dataset,
        data_prefix,
        num_epochs,
        max_num_samples,
        max_seq_length,
        query_in_block_prob,
        seed,
        use_titles=True,
        use_one_sent_docs=False,
        binary_head=False,
    ):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
@@ -61,8 +74,16 @@ class ICTDataset(Dataset):
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
            block_dataset,
            title_dataset,
            data_prefix,
            num_epochs,
            max_num_samples,
            max_seq_length,
            seed,
            name,
            use_one_sent_docs,
        )
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
@@ -99,8 +120,8 @@ class ICTDataset(Dataset):

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
        query = query[: self.max_seq_length - 2]
        block = list(itertools.chain(*block))[: self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
@@ -111,13 +132,13 @@ class ICTDataset(Dataset):
        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_mask': query_mask,
            'query_pad_mask': query_pad_mask,
            'context_tokens': context_tokens,
            'context_mask': context_mask,
            'context_pad_mask': context_pad_mask,
            'block_data': block_data,
            "query_tokens": query_tokens,
            "query_mask": query_mask,
            "query_pad_mask": query_pad_mask,
            "context_tokens": context_tokens,
            "context_mask": context_mask,
            "context_pad_mask": context_pad_mask,
            "block_data": block_data,
        }

        return sample
@@ -127,7 +148,7 @@ class ICTDataset(Dataset):
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block = list(itertools.chain(*block))[: self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

@@ -27,17 +27,17 @@ def __best_fitting_dtype(vocab_size=None):


def get_available_dataset_impl():
return ['lazy', 'cached', 'mmap']
return ["lazy", "cached", "mmap"]


def infer_dataset_impl(path):
if IndexedDataset.exists(path):
with open(index_file_path(path), 'rb') as f:
with open(index_file_path(path), "rb") as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
return 'cached'
return "cached"
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return 'mmap'
return "mmap"
else:
return None
else:
@@ -47,7 +47,7 @@ def infer_dataset_impl(path):


def make_builder(out_file, impl, vocab_size=None):
if impl == 'mmap':
if impl == "mmap":
return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
else:
return IndexedDatasetBuilder(out_file)
@@ -58,20 +58,20 @@ def make_dataset(path, impl, skip_warmup=False):
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
if impl == 'infer':
if impl == "infer":
impl = infer_dataset_impl(path)
if impl == 'lazy' and IndexedDataset.exists(path):
if impl == "lazy" and IndexedDataset.exists(path):
return IndexedDataset(path)
elif impl == 'cached' and IndexedDataset.exists(path):
elif impl == "cached" and IndexedDataset.exists(path):
return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
elif impl == "mmap" and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None


def dataset_exists(path, impl):
if impl == 'mmap':
if impl == "mmap":
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
@@ -98,11 +98,11 @@ def code(dtype):


def index_file_path(prefix_path):
return prefix_path + '.idx'
return prefix_path + ".idx"


def data_file_path(prefix_path):
return prefix_path + '.bin'
return prefix_path + ".bin"


def create_doc_idx(sizes):
@@ -115,7 +115,8 @@ def create_doc_idx(sizes):

class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset"""
_HDR_MAGIC = b'TNTIDX\x00\x00'

_HDR_MAGIC = b"TNTIDX\x00\x00"

def __init__(self, path):
super().__init__()
@@ -124,27 +125,28 @@ class IndexedDataset(torch.utils.data.Dataset):
self.read_index(path)

def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
with open(index_file_path(path), "rb") as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.')
assert magic == self._HDR_MAGIC, (
"Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly."
)
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
assert struct.unpack("<Q", version) == (1,)
code, self.element_size = struct.unpack("<QQ", f.read(16))
self.dtype = dtypes[code]
self._len, self.s = struct.unpack('<QQ', f.read(16))
self.doc_count = struct.unpack('<Q', f.read(8))
self._len, self.s = struct.unpack("<QQ", f.read(16))
self.doc_count = struct.unpack("<Q", f.read(8))
self.dim_offsets = read_longs(f, self._len + 1)
self.data_offsets = read_longs(f, self._len + 1)
self.sizes = read_longs(f, self.s)
self.doc_idx = read_longs(f, self.doc_count)

def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
self.data_file = open(data_file_path(path), "rb", buffering=0)

def check_index(self, i):
if i < 0 or i >= self._len:
raise IndexError('index out of range')
raise IndexError("index out of range")

def __del__(self):
if self.data_file:
@@ -157,7 +159,7 @@ class IndexedDataset(torch.utils.data.Dataset):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
@@ -166,7 +168,7 @@ class IndexedDataset(torch.utils.data.Dataset):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
size = sum(sizes)
a = np.empty(size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[start] * self.element_size)
@@ -186,15 +188,14 @@ class IndexedDataset(torch.utils.data.Dataset):

@staticmethod
def exists(path):
return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))

@property
def supports_prefetch(self):
return False    # avoid prefetching to save memory
return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):

def __init__(self, path):
super().__init__(path)
self.cache = None
@@ -219,7 +220,7 @@ class IndexedCachedDataset(IndexedDataset):
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx:ptx + size]
a = self.cache[ptx : ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
@@ -233,10 +234,10 @@ class IndexedCachedDataset(IndexedDataset):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx:ptx + a.size])
np.copyto(a, self.cache[ptx : ptx + a.size])
return a
elif isinstance(idx, slice):
# Hack just to make this work, can optimizer later if necessary
@@ -250,7 +251,7 @@ class IndexedDatasetBuilder(object):
element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8}

def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.out_file = open(out_file, "wb")
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
@@ -280,7 +281,7 @@ class IndexedDatasetBuilder(object):
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)

with open(data_file_path(another_file), 'rb') as f:
with open(data_file_path(another_file), "rb") as f:
while True:
data = f.read(1024)
if data:
@@ -290,12 +291,12 @@ class IndexedDatasetBuilder(object):

def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
index.write(struct.pack('<Q', len(self.doc_idx)))
index = open(index_file, "wb")
index.write(b"TNTIDX\x00\x00")
index.write(struct.pack("<Q", 1))
index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
index.write(struct.pack("<Q", len(self.doc_idx)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
@@ -304,27 +305,24 @@ class IndexedDatasetBuilder(object):


def _warmup_mmap_file(path):
with open(path, 'rb') as stream:
with open(path, "rb") as stream:
while stream.read(100 * 1024 * 1024):
pass


class MMapIndexedDataset(torch.utils.data.Dataset):

class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
_HDR_MAGIC = b"MMIDIDX\x00\x00"

@classmethod
def writer(cls, path, dtype):

class _Writer(object):

def __enter__(self):
self._file = open(path, 'wb')
self._file = open(path, "wb")

self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack('<Q', 1))
self._file.write(struct.pack('<B', code(dtype)))
self._file.write(struct.pack("<Q", 1))
self._file.write(struct.pack("<B", code(dtype)))

return self

@@ -343,19 +341,19 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)

self._file.write(struct.pack('<Q', len(sizes)))
self._file.write(struct.pack('<Q', len(doc_idx)))
self._file.write(struct.pack("<Q", len(sizes)))
self._file.write(struct.pack("<Q", len(doc_idx)))

sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order='C'))
self._file.write(sizes.tobytes(order="C"))
del sizes

pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order='C'))
self._file.write(pointers.tobytes(order="C"))
del pointers

doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order='C'))
self._file.write(doc_idx.tobytes(order="C"))

def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
@@ -363,39 +361,41 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
return _Writer()

def __init__(self, path, skip_warmup=False):
with open(path, 'rb') as stream:
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.')
version = struct.unpack('<Q', stream.read(8))
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly."
)
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version

dtype_code, = struct.unpack('<B', stream.read(1))
(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize

self._len = struct.unpack('<Q', stream.read(8))[0]
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
self._len = struct.unpack("<Q", stream.read(8))[0]
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()

if not skip_warmup:
print("    warming up index mmap file...")
_warmup_mmap_file(path)

self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print("    reading sizes...")
self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
print("    reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes)
self._pointers = np.frombuffer(
self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes
)
print("    reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
self._doc_idx = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)

def __del__(self):
self._bin_buffer_mmap._mmap.close()
@@ -443,7 +443,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
print("    warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print("    creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode="r", order="C")
print("    creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)

@@ -474,7 +474,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
return sents

def get(self, idx, offset=0, length=None):
""" Retrieves a single item from the dataset with the option to only
"""Retrieves a single item from the dataset with the option to only
return a portion of the item.

get(idx) is the same as [idx] but get() does not support slicing.
@@ -506,20 +506,19 @@ class MMapIndexedDataset(torch.utils.data.Dataset):

@staticmethod
def exists(path):
return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))


class MMapIndexedDatasetBuilder(object):

def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb')
self._data_file = open(out_file, "wb")
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]

def add_item(self, tensor):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._data_file.write(np_array.tobytes(order="C"))
self._sizes.append(np_array.size)

def end_document(self):
@@ -534,7 +533,7 @@ class MMapIndexedDatasetBuilder(object):
self._sizes.append(size)

# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
with open(data_file_path(another_file), "rb") as f:
shutil.copyfileobj(f, self._data_file)

def finalize(self, index_file):

@@ -2,13 +2,12 @@
# put some code used during development and manual testing of
# indexed_dataset.

from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys

import torch
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer

script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))
@@ -42,7 +41,7 @@ def test_indexed_dataset(args):

def test_indexed_dataset_get(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
build_tokenizer(args)
size = ds.sizes[0]
print(f"size: {size}")
full = ds.get(0)
@@ -61,6 +60,7 @@ def test_indexed_dataset_get(args):
print(part)
# print(tokenizer.detokenize(part.data.tolist()))


# def test_albert_dataset(args):
#     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
#     # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
@@ -81,34 +81,27 @@ def test_indexed_dataset_get(args):

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='prefix to data files')
parser.add_argument('--dataset-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'])
parser.add_argument('--count', type=int, default=10,
help='Number of samples/documents to print')
parser.add_argument("--data", type=str, help="prefix to data files")
parser.add_argument("--dataset-impl", type=str, default="infer", choices=["lazy", "cached", "mmap", "infer"])
parser.add_argument("--count", type=int, default=10, help="Number of samples/documents to print")

group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase',
'GPT2BPETokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group = parser.add_argument_group(title="tokenizer")
group.add_argument(
"--tokenizer-type",
type=str,
required=True,
choices=["BertWordPieceLowerCase", "GPT2BPETokenizer"],
help="What type of tokenizer to use.",
)
group.add_argument("--vocab-file", type=str, default=None, help="Path to the vocab file")
group.add_argument("--merge-file", type=str, default=None, help="Path to the BPE merge file (if necessary).")

parser.add_argument('--epochs', type=int, default=5,
help='Number of epochs to plan for')
parser.add_argument('--max-num-samples', type=int, default=None,
help='Maximum number of samples to plan for')
parser.add_argument('--masked-lm-prob', type=float, default=0.15,
help='probability of masking tokens')
parser.add_argument('--seq-length', type=int, default=512,
help='maximum sequence length')
parser.add_argument('--short-seq-prob', type=float, default=0.1,
help='probability of creating a short sequence')
parser.add_argument('--seed', type=int, default=1234,
help='random seed')
parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to plan for")
parser.add_argument("--max-num-samples", type=int, default=None, help="Maximum number of samples to plan for")
parser.add_argument("--masked-lm-prob", type=float, default=0.15, help="probability of masking tokens")
parser.add_argument("--seq-length", type=int, default=512, help="maximum sequence length")
parser.add_argument("--short-seq-prob", type=float, default=0.1, help="probability of creating a short sequence")
parser.add_argument("--seed", type=int, default=1234, help="random seed")
args = parser.parse_args()
args.rank = 0
args.make_vocab_size_divisible_by = 128
@@ -117,7 +110,7 @@ def main():
if args.dataset_impl == "infer":
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

#    test_albert_dataset(args)
#     test_albert_dataset(args)
test_indexed_dataset_get(args)



@@ -1,8 +1,7 @@
import torch


class DummyDataloader():

class DummyDataloader:
def __init__(self, batch_size, vocab_size, seq_length):
self.batch_size = batch_size
self.vocab_size = vocab_size
@@ -10,30 +9,44 @@ class DummyDataloader():
self.step = 0

def generate(self):
tokens = torch.randint(low=0, high=self.vocab_size, size=(
self.batch_size,
self.seq_length,
))
types = torch.randint(low=0, high=3, size=(
self.batch_size,
self.seq_length,
))
tokens = torch.randint(
low=0,
high=self.vocab_size,
size=(
self.batch_size,
self.seq_length,
),
)
types = torch.randint(
low=0,
high=3,
size=(
self.batch_size,
self.seq_length,
),
)
sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
loss_mask = torch.randint(low=0, high=2, size=(
self.batch_size,
self.seq_length,
))
loss_mask = torch.randint(
low=0,
high=2,
size=(
self.batch_size,
self.seq_length,
),
)
lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
return dict(text=tokens,
types=types,
is_random=sentence_order,
loss_mask=loss_mask,
labels=lm_labels,
padding_mask=padding_mask)
return dict(
text=tokens,
types=types,
is_random=sentence_order,
loss_mask=loss_mask,
labels=lm_labels,
padding_mask=padding_mask,
)

def __iter__(self):
return self

def __next__(self):
return self.generate()
return self.generate()

@@ -16,7 +16,6 @@

from .tokenizer import build_tokenizer


_TOKENIZER = None
_PADDED_VOCAB_SIZE = -1


@@ -15,13 +15,12 @@

"""Tokenization classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function

import collections
import re
import unicodedata

import six


@@ -43,14 +42,13 @@ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
model_name = m.group(1)

lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
"uncased_L-24_H-1024_A-16",
"uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12",
"chinese_L-12_H-768_A-12",
]

cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
cased_models = ["cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", "multi_cased_L-12_H-768_A-12"]

is_bad_config = False
if model_name in lower_models and not do_lower_case:
@@ -71,8 +69,8 @@ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
"just comment out this check." % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag)
)


def convert_to_unicode(text):
@@ -183,27 +181,27 @@ class FullTokenizer(object):

@staticmethod
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
""" Converts a sequence of tokens (string) in a single string. """
"""Converts a sequence of tokens (string) in a single string."""

def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts
"""Clean up a list of simple English tokenization artifacts
like spaces before punctuations and abbreviated forms.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string

text = ' '.join(tokens).replace(' ##', '').strip()
text = " ".join(tokens).replace(" ##", "").strip()
if clean_up_tokenization_spaces:
clean_text = clean_up_tokenization(text)
return clean_text
@@ -303,14 +301,16 @@ class BasicTokenizer(object):
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or    #
(cp >= 0x3400 and cp <= 0x4DBF) or    #
(cp >= 0x20000 and cp <= 0x2A6DF) or    #
(cp >= 0x2A700 and cp <= 0x2B73F) or    #
(cp >= 0x2B740 and cp <= 0x2B81F) or    #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or    #
(cp >= 0x2F800 and cp <= 0x2FA1F)):    #
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF)  #
or (cp >= 0x20000 and cp <= 0x2A6DF)  #
or (cp >= 0x2A700 and cp <= 0x2B73F)  #
or (cp >= 0x2B740 and cp <= 0x2B81F)  #
or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
):  #
return True

return False
@@ -320,7 +320,7 @@ class BasicTokenizer(object):
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
if cp == 0 or cp == 0xFFFD or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
@@ -422,8 +422,7 @@ def _is_punctuation(char):
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):

@@ -25,16 +25,15 @@ from .bert_tokenization import FullTokenizer as FullBertTokenizer
def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
"""Initialize tokenizer."""
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
print('> building {} tokenizer ...'.format(tokenizer_type), flush=True)
print("> building {} tokenizer ...".format(tokenizer_type), flush=True)

# Select and instantiate the tokenizer.
if tokenizer_type == 'BertWordPieceLowerCase':
if tokenizer_type == "BertWordPieceLowerCase":
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids)
elif tokenizer_type == 'BertWordPieceCase':
elif tokenizer_type == "BertWordPieceCase":
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(tokenizer_type))
raise NotImplementedError("{} tokenizer is not " "implemented.".format(tokenizer_type))

# Add vocab size.
padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size)
@@ -55,9 +54,11 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128):
while (after % multiple) != 0:
after += 1
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
'(new size: {})'.format(orig_vocab_size, after - orig_vocab_size, after),
flush=True)
print(
" > padded vocab (size: {}) with {} dummy tokens "
"(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after),
flush=True,
)
return after


@@ -77,46 +78,38 @@ class AbstractTokenizer(ABC):
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass

@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass

@abstractmethod
def tokenize(self, text):
pass

def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name))

@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name))

@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name))

@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name))

@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name))

@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name))


class _BertWordPieceTokenizer(AbstractTokenizer):
@@ -124,24 +117,24 @@ class _BertWordPieceTokenizer(AbstractTokenizer):

def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
if lower_case:
name = 'BERT Lower Case'
name = "BERT Lower Case"
else:
name = 'BERT Upper Case'
name = "BERT Upper Case"
super().__init__(name)
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]']
self.cls_id = self.tokenizer.vocab["[CLS]"]
self.sep_id = self.tokenizer.vocab["[SEP]"]
self.pad_id = self.tokenizer.vocab["[PAD]"]
self.mask_id = self.tokenizer.vocab["[MASK]"]
self._additional_special_tokens = []

# (dsachan) Add BOS and EOS tokens
SPECIAL_TOKENS = {'eos_token': '[EOS]', 'bos_token': '[BOS]'}
self._bos_token = '[BOS]'
SPECIAL_TOKENS = {"eos_token": "[EOS]", "bos_token": "[BOS]"}
self._bos_token = "[BOS]"
self.add_token(self._bos_token)
self._bos_token_id = self.vocab.get(self._bos_token)

self._eos_token = '[EOS]'
self._eos_token = "[EOS]"
self.add_token(self._eos_token)
self._eos_token_id = self.vocab.get(self._eos_token)

@@ -185,7 +178,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):

def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
exclude_list = ["[PAD]", "[CLS]"]
non_pads = [t for t in tokens if t not in exclude_list]

result = ""
@@ -215,32 +208,32 @@ class _BertWordPieceTokenizer(AbstractTokenizer):

@property
def bos_token(self):
""" Beginning of sentence token id """
"""Beginning of sentence token id"""
return self._bos_token

@property
def eos_token(self):
""" End of sentence token id """
"""End of sentence token id"""
return self._eos_token

@property
def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings)."""
"""All the additional special tokens you may want to use (list of strings)."""
return self._additional_special_tokens

@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary."""
"""Id of the beginning of sentence token in the vocabulary."""
return self._bos_token_id

@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary."""
"""Id of the end of sentence token in the vocabulary."""
return self._eos_token_id

@property
def additional_special_tokens_ids(self):
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
"""Ids of all the additional special tokens in the vocabulary (list of integers)."""
return [self.vocab.get(token) for token in self._additional_special_tokens]

@additional_special_tokens.setter

@@ -1,17 +1,12 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger

from .cross_entropy import vocab_cross_entropy


class BertLoss(nn.Module):

def forward(self, lm_loss, sop_logits, loss_mask, sentence_order):
lm_loss_ = lm_loss.float()
loss_mask = loss_mask.float()

@@ -1,11 +1,8 @@
import torch
from torch.cuda.amp import custom_bwd, custom_fwd

from colossalai.legacy.context.parallel_mode import ParallelMode


class _VocabCrossEntropy(torch.autograd.Function):

@staticmethod
@custom_fwd
def forward(ctx, vocab_parallel_logits, target):
@@ -59,7 +56,7 @@ class _VocabCrossEntropy(torch.autograd.Function):

# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()

# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))

@@ -1,11 +1,9 @@

import torch


def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(
numerator, denominator)
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)


def divide(numerator, denominator):
@@ -15,8 +13,7 @@ def divide(numerator, denominator):
return numerator // denominator


def split_tensor_along_last_dim(tensor, num_partitions,
contiguous_split_chunks=False):
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
@@ -38,12 +35,11 @@ def split_tensor_along_last_dim(tensor, num_partitions,

class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""

@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank, world_size):
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@@ -51,5 +47,4 @@ class VocabUtility:
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)

@@ -21,16 +21,17 @@ import math
class AnnealingLR(object):
"""Anneals the learning rate."""

def __init__(self,
optimizer,
max_lr,
min_lr,
warmup_steps,
decay_steps,
decay_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):

def __init__(
self,
optimizer,
max_lr,
min_lr,
warmup_steps,
decay_steps,
decay_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False,
):
# Class values.
self.optimizer = optimizer

@@ -50,23 +51,21 @@ class AnnealingLR(object):
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
'use-checkpoint are set.'
assert not self.use_checkpoint_lr_scheduler, "both override and " "use-checkpoint are set."

# Set the learning rate
self.step(0)

def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

# Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
return self.max_lr * float(self.num_steps) / \
float(self.warmup_steps)
return self.max_lr * float(self.num_steps) / float(self.warmup_steps)

# If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant':
if self.decay_style == "constant":
return self.max_lr

# For any steps larger than `self.decay_steps`, use `self.min_lr`.
@@ -81,13 +80,12 @@ class AnnealingLR(object):
assert decay_ratio <= 1.0
delta_lr = self.max_lr - self.min_lr

if self.decay_style == 'linear':
coeff = (1.0 - decay_ratio)
elif self.decay_style == 'cosine':
if self.decay_style == "linear":
coeff = 1.0 - decay_ratio
elif self.decay_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else:
raise Exception('{} decay style is not supported.'.format(
self.decay_style))
raise Exception("{} decay style is not supported.".format(self.decay_style))

return self.min_lr + coeff * delta_lr

@@ -96,16 +94,16 @@ class AnnealingLR(object):
self.num_steps += increment
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
group["lr"] = new_lr

def state_dict(self):
state_dict = {
'max_lr': self.max_lr,
'warmup_steps': self.warmup_steps,
'num_steps': self.num_steps,
'decay_style': self.decay_style,
'decay_steps': self.decay_steps,
'min_lr': self.min_lr
"max_lr": self.max_lr,
"warmup_steps": self.warmup_steps,
"num_steps": self.num_steps,
"decay_style": self.decay_style,
"decay_steps": self.decay_steps,
"min_lr": self.min_lr,
}
return state_dict

@@ -116,43 +114,35 @@ class AnnealingLR(object):
return cls_value

if not self.use_checkpoint_lr_scheduler:
assert cls_value == sd_value, \
f'AnnealingLR: class input value {cls_value} and checkpoint' \
f'value {sd_value} for {name} do not match'
assert cls_value == sd_value, (
f"AnnealingLR: class input value {cls_value} and checkpoint" f"value {sd_value} for {name} do not match"
)
return sd_value

def load_state_dict(self, sd):

if 'start_lr' in sd:
max_lr_ = sd['start_lr']
if "start_lr" in sd:
max_lr_ = sd["start_lr"]
else:
max_lr_ = sd['max_lr']
self.max_lr = self._check_and_set(self.max_lr, max_lr_,
'learning rate')
max_lr_ = sd["max_lr"]
self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")

self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
'minimum learning rate')
self.min_lr = self._check_and_set(self.min_lr, sd["min_lr"], "minimum learning rate")

if 'warmup_iter' in sd:
warmup_steps_ = sd['warmup_iter']
if "warmup_iter" in sd:
warmup_steps_ = sd["warmup_iter"]
else:
warmup_steps_ = sd['warmup_steps']
self.warmup_steps = self._check_and_set(self.warmup_steps,
warmup_steps_,
'warmup iterations')
warmup_steps_ = sd["warmup_steps"]
self.warmup_steps = self._check_and_set(self.warmup_steps, warmup_steps_, "warmup iterations")

if 'end_iter' in sd:
decay_steps_ = sd['end_iter']
if "end_iter" in sd:
decay_steps_ = sd["end_iter"]
else:
decay_steps_ = sd['decay_steps']
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
'total number of iterations')
self.decay_style = self._check_and_set(self.decay_style,
sd['decay_style'],
'decay style')
decay_steps_ = sd["decay_steps"]
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, "total number of iterations")
self.decay_style = self._check_and_set(self.decay_style, sd["decay_style"], "decay style")

if 'num_iters' in sd:
num_steps = sd['num_iters']
if "num_iters" in sd:
num_steps = sd["num_iters"]
else:
num_steps = sd['num_steps']
num_steps = sd["num_steps"]
self.step(increment=num_steps)

@@ -1,2 +0,0 @@



@@ -16,7 +16,6 @@ from .layers.init_method import init_normal, output_init_normal


class BertForPretrain(nn.Module):

def __init__(
self,
vocab_size,
@@ -34,7 +33,9 @@ class BertForPretrain(nn.Module):
):
super().__init__()
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
assert (
max_sequence_length % self.seq_parallel_size == 0
), "sequence length is not divisible by the sequence parallel size"
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
self.init_std = init_std
self.num_layers = num_layers
@@ -43,28 +44,32 @@ class BertForPretrain(nn.Module):
num_tokentypes = 0

self.preprocessor = PreProcessor(self.sub_seq_length)
self.embedding = Embedding(hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes)
self.embedding = Embedding(
hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes,
)
self.bert_layers = nn.ModuleList()

for i in range(num_layers):
bert_layer = BertLayer(layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16)
bert_layer = BertLayer(
layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16,
)
self.bert_layers.append(bert_layer)

self.layer_norm = LayerNorm(hidden_size)
self.head = BertDualHead(hidden_size,
self.embedding.word_embedding_weight.size(0),
add_binary_head=add_binary_head)
self.head = BertDualHead(
hidden_size, self.embedding.word_embedding_weight.size(0), add_binary_head=add_binary_head
)
self.reset_parameters()

def _init_normal(self, tensor):
@@ -122,27 +127,30 @@ class BertForPretrain(nn.Module):


class PipelineBertForPretrain(nn.Module):

def __init__(self,
vocab_size,
hidden_size,
max_sequence_length,
num_attention_heads,
num_layers,
add_binary_head,
is_naive_fp16,
num_tokentypes=2,
dropout_prob=0.1,
mlp_ratio=4,
init_std=0.02,
convert_fp16_to_fp32_in_softmax=False,
first_stage=True,
last_stage=True,
start_idx=None,
end_idx=None):
def __init__(
self,
vocab_size,
hidden_size,
max_sequence_length,
num_attention_heads,
num_layers,
add_binary_head,
is_naive_fp16,
num_tokentypes=2,
dropout_prob=0.1,
mlp_ratio=4,
init_std=0.02,
convert_fp16_to_fp32_in_softmax=False,
first_stage=True,
last_stage=True,
start_idx=None,
end_idx=None,
):
super().__init__()
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
assert (
max_sequence_length % self.seq_parallel_size == 0
), "sequence length is not divisible by the sequence parallel size"
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
self.init_std = init_std
self.num_layers = num_layers
@@ -156,11 +164,13 @@ class PipelineBertForPretrain(nn.Module):
self.preprocessor = PreProcessor(self.sub_seq_length)

if self.first_stage:
self.embedding = Embedding(hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes)
self.embedding = Embedding(
hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes,
)

# transformer layers
self.bert_layers = nn.ModuleList()
@@ -170,14 +180,16 @@ class PipelineBertForPretrain(nn.Module):
end_idx = num_layers

for i in range(start_idx, end_idx):
bert_layer = BertLayer(layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16)
bert_layer = BertLayer(
layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16,
)
self.bert_layers.append(bert_layer)

if self.last_stage:
@@ -256,7 +268,7 @@ def _filter_kwargs(func, kwargs):
return {k: v for k, v in kwargs.items() if k in sig.parameters}


def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
def build_pipeline_bert(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
logger = get_dist_logger()
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -265,12 +277,12 @@ def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **k
parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
models = []
for start, end in parts:
kwargs['num_layers'] = num_layers
kwargs['start_idx'] = start
kwargs['end_idx'] = end
kwargs['first_stage'] = start == 0
kwargs['last_stage'] = end == num_layers
logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
kwargs["num_layers"] = num_layers
kwargs["start_idx"] = start
kwargs["end_idx"] = end
kwargs["first_stage"] = start == 0
kwargs["last_stage"] = end == num_layers
logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers")
chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device)
if start == 0:
wrapper.register_module(chunk.embedding.word_embeddings)

@@ -1,4 +1,4 @@
|
||||
from .embedding import VocabEmbedding, Embedding
|
||||
from .bert_layer import BertLayer
|
||||
from .embedding import Embedding, VocabEmbedding
|
||||
from .head import BertDualHead
|
||||
from .preprocess import PreProcessor
|
||||
|
||||
@@ -20,18 +20,20 @@ class BertLayer(nn.Module):
|
||||
output of the same size.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
layer_number,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout,
|
||||
mlp_ratio,
|
||||
hidden_dropout,
|
||||
is_naive_fp16,
|
||||
apply_residual_connection_post_layernorm=False,
|
||||
fp32_residual_connection=False,
|
||||
bias_dropout_fusion: bool = True,
|
||||
convert_fp16_to_fp32_in_softmax: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
layer_number,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout,
|
||||
mlp_ratio,
|
||||
hidden_dropout,
|
||||
is_naive_fp16,
|
||||
apply_residual_connection_post_layernorm=False,
|
||||
fp32_residual_connection=False,
|
||||
bias_dropout_fusion: bool = True,
|
||||
convert_fp16_to_fp32_in_softmax: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_number = layer_number
|
||||
|
||||
@@ -50,7 +52,8 @@ class BertLayer(nn.Module):
|
||||
layer_number=layer_number,
|
||||
apply_query_key_layer_scaling=True,
|
||||
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
|
||||
fp16=is_naive_fp16)
|
||||
fp16=is_naive_fp16,
|
||||
)
|
||||
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.bias_dropout_fusion = bias_dropout_fusion
|
||||
@@ -90,8 +93,9 @@ class BertLayer(nn.Module):
|
||||
|
||||
# re-enable torch grad to enable fused optimization.
|
||||
with torch.enable_grad():
|
||||
layernorm_input = bias_dropout_add_func(attention_output, attention_bias.expand_as(residual), residual,
|
||||
self.hidden_dropout)
|
||||
layernorm_input = bias_dropout_add_func(
|
||||
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
|
||||
)
|
||||
|
||||
# Layer norm post the self attention.
|
||||
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
|
||||
|
||||
def bias_dropout_add(x, bias, residual, prob, training):
|
||||
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
|
||||
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
|
||||
@@ -10,4 +11,5 @@ def bias_dropout_add(x, bias, residual, prob, training):
|
||||
def get_bias_dropout_add(training):
|
||||
def _bias_dropout_add(x, bias, residual, prob):
|
||||
return bias_dropout_add(x, bias, residual, prob, training)
|
||||
return _bias_dropout_add
|
||||
|
||||
return _bias_dropout_add

@@ -5,7 +5,6 @@ import torch.nn.init as init


class VocabEmbedding(torch.nn.Module):

def __init__(self, num_embeddings, embedding_dim):
super(VocabEmbedding, self).__init__()
# Keep the input dimensions.
@@ -13,26 +12,29 @@ class VocabEmbedding(torch.nn.Module):
self.embedding_dim = embedding_dim
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None

# Allocate weights and initialize.
self.weight = nn.Parameter(torch.empty(
self.num_embeddings, self.embedding_dim))
self.weight = nn.Parameter(torch.empty(self.num_embeddings, self.embedding_dim))
init.xavier_uniform_(self.weight)

def forward(self, hidden_state):
output = F.embedding(hidden_state, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
output = F.embedding(
hidden_state,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
return output

def __repr__(self):
return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \
f'embedding_dim={self.embedding_dim})'
return f"VocabEmbedding(num_embeddings={self.num_embeddings}, " f"embedding_dim={self.embedding_dim})"


class Embedding(nn.Module):
@@ -48,12 +50,7 @@ class Embedding(nn.Module):
will ignore this embedding
"""

def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes):
def __init__(self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes):
super(Embedding, self).__init__()

self.hidden_size = hidden_size
@@ -62,16 +59,14 @@ class Embedding(nn.Module):
self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size)

# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)

# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
else:
self.tokentype_embeddings = None


@@ -3,12 +3,10 @@ import torch.nn as nn
import torch.nn.functional as F
from loss_func.cross_entropy import vocab_cross_entropy

import colossalai
from colossalai.kernel import LayerNorm
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc

from .embedding import VocabEmbedding
from .linear import Linear
from .pooler import Pooler

@@ -26,7 +24,6 @@ class BertLMHead(nn.Module):
vocab_size,
hidden_size,
):

super(BertLMHead, self).__init__()
self.bias = torch.nn.Parameter(torch.zeros(vocab_size))

@@ -46,7 +43,6 @@ class BertLMHead(nn.Module):


class BertBinaryHead(nn.Module):

def __init__(self, hidden_size):
super().__init__()
self.pooler = Pooler(hidden_size)
@@ -62,7 +58,6 @@ class BertBinaryHead(nn.Module):


class BertDualHead(nn.Module):

def __init__(self, hidden_size, vocab_size, add_binary_head):
super().__init__()
self.lm_head = BertLMHead(vocab_size, hidden_size)

@@ -1,6 +1,8 @@
import torch
import math

import torch


def init_normal(tensor, sigma):
"""Init method based on N(0, sigma)."""
torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import Parameter


class Linear(nn.Module):
@@ -24,11 +24,7 @@ class Linear(nn.Module):
adding bias but instead return it.
"""

def __init__(self,
input_size,
output_size,
bias=True,
skip_bias_add=False):
def __init__(self, input_size, output_size, bias=True, skip_bias_add=False):
super(Linear, self).__init__()

# Keep input parameters
@@ -36,9 +32,12 @@ class Linear(nn.Module):
self.output_size = output_size
self.skip_bias_add = skip_bias_add

self.weight = Parameter(torch.empty(self.output_size,
self.input_size,
))
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size,
)
)
init.normal_(self.weight)
if bias:
self.bias = Parameter(torch.empty(self.output_size))
@@ -46,7 +45,7 @@ class Linear(nn.Module):
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)

def forward(self, input_):
# Matrix multiply.
@@ -59,5 +58,7 @@ class Linear(nn.Module):
return output

def __repr__(self):
return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
return (
f"Linear(in_features={self.input_size}, out_features={self.output_size}, "
+ f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})"
)

@@ -1,10 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .linear import Linear
from colossalai.kernel.jit import bias_gelu_impl

from .linear import Linear


class TransformerMLP(nn.Module):
"""MLP.
@@ -18,19 +18,13 @@ class TransformerMLP(nn.Module):
super(TransformerMLP, self).__init__()

# Project to 4h.
self.dense_h_to_4h = Linear(
hidden_size,
int(hidden_size*mlp_ratio),
skip_bias_add=True)
self.dense_h_to_4h = Linear(hidden_size, int(hidden_size * mlp_ratio), skip_bias_add=True)

self.bias_gelu_fusion = fuse_gelu
self.activation_func = F.gelu

# Project back to h.
self.dense_4h_to_h = Linear(
int(hidden_size*mlp_ratio),
hidden_size,
skip_bias_add=True)
self.dense_4h_to_h = Linear(int(hidden_size * mlp_ratio), hidden_size, skip_bias_add=True)

def forward(self, hidden_states):
# hidden states should be in the shape of [s, b, h]
@@ -39,11 +33,9 @@ class TransformerMLP(nn.Module):
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

if self.bias_gelu_fusion:
intermediate_parallel = \
bias_gelu_impl(intermediate_parallel, bias_parallel)
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = \
self.activation_func(intermediate_parallel + bias_parallel)
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)

# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
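
Aside (not part of the diff): the MLP above leans on Linear's skip_bias_add so the first projection's bias can be folded into the GeLU (JIT-fused by bias_gelu_impl in the real code). A hedged eager-mode sketch of the same dataflow, with made-up sizes:

import torch
import torch.nn.functional as F

hidden = torch.randn(16, 2, 1024)                 # [s, b, h]; sizes assumed
w1, b1 = torch.randn(4096, 1024), torch.zeros(4096)
w2, b2 = torch.randn(1024, 4096), torch.zeros(1024)

intermediate = F.linear(hidden, w1)               # bias deliberately skipped ...
intermediate = F.gelu(intermediate + b1)          # ... so bias + GeLU apply in one step
output = F.linear(intermediate, w2)               # second projection, bias handed back
output_bias = b2                                  # caller fuses it downstream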

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn

from .linear import Linear



@@ -6,7 +6,6 @@ from colossalai.legacy.core import global_context as gpc


class PreProcessor(nn.Module):

def __init__(self, sub_seq_length):
super().__init__()
self.sub_seq_length = sub_seq_length
@@ -15,10 +14,9 @@ class PreProcessor(nn.Module):
# Create position ids
seq_length = token_ids.size(1)
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
position_ids = torch.arange(seq_length * local_rank,
seq_length * (local_rank + 1),
dtype=torch.long,
device=token_ids.device)
position_ids = torch.arange(
seq_length * local_rank, seq_length * (local_rank + 1), dtype=torch.long, device=token_ids.device
)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

return position_ids
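
Aside (not part of the diff): a minimal sketch of the rank-local position ids computed above, with the distributed context replaced by plain integers (assumed values). Each sequence-parallel rank owns a contiguous slice of positions:

import torch

sub_seq_length, local_rank = 128, 1   # assumed: 4-way sequence parallelism, rank 1
token_ids = torch.randint(0, 30528, (2, sub_seq_length))
seq_length = token_ids.size(1)
position_ids = torch.arange(
    seq_length * local_rank, seq_length * (local_rank + 1), dtype=torch.long
)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)  # rank 1 owns positions 128..255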
@@ -42,7 +40,7 @@ class PreProcessor(nn.Module):
extended_attention_mask = attention_mask_bss.unsqueeze(1)

# Convert attention mask to binary:
extended_attention_mask = (extended_attention_mask < 0.5)
extended_attention_mask = extended_attention_mask < 0.5

return extended_attention_mask
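
Aside (not part of the diff): a hedged sketch of the binarization above. How attention_mask_bss is built is not shown in this hunk, so the outer-product construction below is an assumption; only the unsqueeze and the < 0.5 conversion come from the code:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                                    # [b, s]; 0 marks padding
attention_mask_bss = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)   # assumed [b, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)                        # [b, 1, s, s]
extended_attention_mask = extended_attention_mask < 0.5                          # True = masked out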


@@ -12,7 +12,6 @@ from colossalai.kernel import LayerNorm
from colossalai.legacy.amp import AMP_TYPE
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine.schedule import PipelineSchedule
from colossalai.legacy.utils import is_using_pp
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import FusedAdam
@@ -31,7 +30,7 @@ def process_batch_data(batch_data):

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
parser.add_argument("-s", "--synthetic", action="store_true", help="whether use synthetic data")
return parser.parse_args()


@@ -48,37 +47,39 @@ def pipeline_data_process_func(stage_output, micro_batch_data):

def main():
# initialize
args = parse_args()
colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')
parse_args()
colossalai.launch_from_torch(config="./config.py", seed=1234, backend="nccl")

logger = get_dist_logger()

# build synthetic dataloader
BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
VOCAB_SIZE = 30528
trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
vocab_size=VOCAB_SIZE,
seq_length=gpc.config.SEQ_LENGTH)
validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
vocab_size=VOCAB_SIZE,
seq_length=gpc.config.SEQ_LENGTH)
trainloader = DummyDataloader(
batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH
)
validloader = DummyDataloader(
batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH
)

logger.info("Dataloaders are built", ranks=[0])

# build model
if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE:
if hasattr(gpc.config, "fp16") and gpc.config.fp16.get("mode") == AMP_TYPE.NAIVE:
is_naive_fp16 = True
else:
is_naive_fp16 = False

use_pipeline = is_using_pp()
kwargs = dict(vocab_size=VOCAB_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
max_sequence_length=gpc.config.SEQ_LENGTH,
num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
convert_fp16_to_fp32_in_softmax=True,
is_naive_fp16=is_naive_fp16,
add_binary_head=gpc.config.ADD_BINARY_HEAD)
kwargs = dict(
vocab_size=VOCAB_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
max_sequence_length=gpc.config.SEQ_LENGTH,
num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
convert_fp16_to_fp32_in_softmax=True,
is_naive_fp16=is_naive_fp16,
add_binary_head=gpc.config.ADD_BINARY_HEAD,
)

if use_pipeline:
model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs)
@@ -99,35 +100,39 @@ def main():
logger.info("Criterion is built", ranks=[0])

# layernorm and bias has no weight decay
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
weight_decay_params = {"params": []}
no_weight_decay_params = {"params": [], "weight_decay": 0.0}
for module_ in model.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None])
no_weight_decay_params["params"].extend([p for p in list(module_._parameters.values()) if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias'])
weight_decay_params["params"].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n != "bias"]
)
no_weight_decay_params["params"].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n == "bias"]
)

logger.info(
f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}"
)
# optimizer
optimizer = FusedAdam((weight_decay_params, no_weight_decay_params),
lr=gpc.config.LR,
weight_decay=gpc.config.WEIGHT_DECAY)
optimizer = FusedAdam(
(weight_decay_params, no_weight_decay_params), lr=gpc.config.LR, weight_decay=gpc.config.WEIGHT_DECAY
)
logger.info("Optimizer is built", ranks=[0])

# lr scheduler
# follow Megatron-LM setting
warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION)
lr_scheduler = AnnealingLR(optimizer=optimizer,
max_lr=gpc.config.LR,
min_lr=gpc.config.MIN_LR,
warmup_steps=warmup_steps,
decay_steps=gpc.config.DECAY_ITERS,
decay_style='linear')
lr_scheduler = AnnealingLR(
optimizer=optimizer,
max_lr=gpc.config.LR,
min_lr=gpc.config.MIN_LR,
warmup_steps=warmup_steps,
decay_steps=gpc.config.DECAY_ITERS,
decay_style="linear",
)
logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")

# # init
@@ -135,7 +140,6 @@ def main():

# build timer
timer = MultiTimer()
skip_iters = 0

# build loss tracker
accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda()
@@ -150,7 +154,7 @@ def main():
logger.info("start training")

for step in range(1, gpc.config.TRAIN_ITERS + 1):
timer.start('train-iterations')
timer.start("train-iterations")
engine.train()
if use_pipeline:
engine.zero_grad()
@@ -158,13 +162,14 @@ def main():
engine.step()
else:
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
trainloader)
trainloader
)
engine.zero_grad()
lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
engine.backward(train_loss)
engine.step()
timer.stop('train-iterations', keep_in_history=True)
timer.stop("train-iterations", keep_in_history=True)

if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
accumulated_train_loss += train_loss
@@ -177,12 +182,18 @@ def main():
for j in range(gpc.config.EVAL_ITERS):
with torch.no_grad():
if use_pipeline:
_, _, eval_loss = engine.execute_schedule(valid_data_iter,
forward_only=True,
return_output_label=False)
_, _, eval_loss = engine.execute_schedule(
valid_data_iter, forward_only=True, return_output_label=False
)
else:
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
validloader)
(
tokens,
types,
sentence_order,
loss_mask,
lm_labels,
padding_mask,
) = get_batch_for_sequence_parallel(validloader)
lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)

@@ -196,18 +207,22 @@ def main():
timer_string = []
for n, t in timer:
timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}")
timer_string = ' | '.join(timer_string)
lr = list(engine.optimizer.param_groups)[0]['lr']
timer_string = " | ".join(timer_string)
lr = list(engine.optimizer.param_groups)[0]["lr"]
loss_scale = engine.optimizer.optim.loss_scale.item()

if gpc.is_initialized(ParallelMode.PIPELINE):
ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]]
else:
ranks = [0]
logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' +
f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' +
f"| Learning rate: {lr} | " + timer_string,
ranks=ranks)
logger.info(
f"Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} "
+ f"| Eval Loss: {accumulated_eval_loss.item():.5g} "
+ f"| Loss Scale: {loss_scale}"
+ f"| Learning rate: {lr} | "
+ timer_string,
ranks=ranks,
)

for n, t in timer:
t.reset()
@@ -215,5 +230,5 @@ def main():
accumulated_train_loss.zero_()


if __name__ == '__main__':
if __name__ == "__main__":
main()