[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -4,7 +4,7 @@ from colossalai.legacy.amp import AMP_TYPE
TRAIN_ITERS = 10
DECAY_ITERS = 4
WARMUP_FRACTION = 0.01
GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU
GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU
EVAL_ITERS = 10
EVAL_INTERVAL = 10
LR = 0.0001
@@ -28,8 +28,8 @@ SEED = 1234
NUM_MICRO_BATCHES = 4
# colossalai config
parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence'))
parallel = dict(pipeline=1, tensor=dict(size=2, mode="sequence"))
fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True)
gradient_handler = [dict(type='SequenceParallelGradientHandler')]
gradient_handler = [dict(type="SequenceParallelGradientHandler")]

View File

@@ -15,16 +15,13 @@ def cyclic_iter(iter):
yield x
def build_train_valid_test_data_iterators(train_iters,
global_batch_size,
eval_interval,
eval_iters,
dataloader_type='single',
**kwargs):
def build_train_valid_test_data_iterators(
train_iters, global_batch_size, eval_interval, eval_iters, dataloader_type="single", **kwargs
):
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
logger = get_dist_logger()
logger.info('> building train, validation, and test datasets ...', ranks=[0])
logger.info("> building train, validation, and test datasets ...", ranks=[0])
# Backward compatibility, assume fixed batch size.
# if iteration > 0 and consumed_train_samples == 0:
@@ -38,29 +35,29 @@ def build_train_valid_test_data_iterators(train_iters,
# Data loader only on rank 0 of each model parallel group.
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# Number of train/valid/test samples.
train_samples = train_iters * global_batch_size
eval_iters_ = (train_iters // eval_interval + 1) * eval_iters
test_iters = eval_iters
train_val_test_num_samples = [train_samples, eval_iters_ * global_batch_size, test_iters * global_batch_size]
logger.info(' > datasets target sizes (minimum size):')
logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0])
logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0])
logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0])
logger.info(" > datasets target sizes (minimum size):")
logger.info(" train: {}".format(train_val_test_num_samples[0]), ranks=[0])
logger.info(" validation: {}".format(train_val_test_num_samples[1]), ranks=[0])
logger.info(" test: {}".format(train_val_test_num_samples[2]), ranks=[0])
# Build the datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
train_valid_test_num_samples=train_val_test_num_samples, **kwargs)
train_valid_test_num_samples=train_val_test_num_samples, **kwargs
)
# Build dataloaders.
dp_size = gpc.get_world_size(ParallelMode.DATA)
train_dataloader = build_pretraining_data_loader(train_ds,
consumed_samples=0,
micro_batch_size=global_batch_size // dp_size)
valid_dataloader = build_pretraining_data_loader(valid_ds,
consumed_samples=0,
micro_batch_size=global_batch_size // dp_size)
train_dataloader = build_pretraining_data_loader(
train_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size
)
valid_dataloader = build_pretraining_data_loader(
valid_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size
)
test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size // dp_size)
# Flags to know if we need to do training/validation/testing.
@@ -73,29 +70,26 @@ def build_train_valid_test_data_iterators(train_iters,
flags = torch.cuda.LongTensor([0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(flags,
gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
torch.distributed.broadcast(
flags, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
)
# Build iterators.
dl_type = dataloader_type
assert dl_type in ['single', 'cyclic']
assert dl_type in ["single", "cyclic"]
if train_dataloader is not None:
train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(train_dataloader))
train_data_iterator = iter(train_dataloader) if dl_type == "single" else iter(cyclic_iter(train_dataloader))
else:
train_data_iterator = None
if valid_dataloader is not None:
valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(valid_dataloader))
valid_data_iterator = iter(valid_dataloader) if dl_type == "single" else iter(cyclic_iter(valid_dataloader))
else:
valid_data_iterator = None
if test_dataloader is not None:
test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(test_dataloader))
test_data_iterator = iter(test_dataloader) if dl_type == "single" else iter(cyclic_iter(test_dataloader))
else:
test_data_iterator = None

View File

@@ -15,7 +15,7 @@ def _build_key_size_numel_dictionaries(keys, data):
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
size = data[key].size()
for i, s in enumerate(size):
sizes[i + offset] = s
@@ -23,9 +23,9 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(sizes_cuda,
gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
torch.distributed.broadcast(
sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
)
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
@@ -73,9 +73,9 @@ def broadcast_data(keys, data, datatype):
flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
# Broadcast
torch.distributed.broadcast(flatten_data,
gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
torch.distributed.broadcast(
flatten_data, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)
)
# Unpack
output = {}
@@ -93,7 +93,7 @@ def get_batch(data_iterator):
"""Build the batch."""
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
datatype = torch.int64
# Broadcast data.
@@ -104,12 +104,12 @@ def get_batch(data_iterator):
data_b = broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['text'].long()
types = data_b['types'].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'].float()
lm_labels = data_b['labels'].long()
padding_mask = data_b['padding_mask'].long()
tokens = data_b["text"].long()
types = data_b["types"].long()
sentence_order = data_b["is_random"].long()
loss_mask = data_b["loss_mask"].float()
lm_labels = data_b["labels"].long()
padding_mask = data_b["padding_mask"].long()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
@@ -118,7 +118,7 @@ def get_batch_for_sequence_parallel(data_iterator):
"""Build the batch."""
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
datatype = torch.int64
# Broadcast data.
@@ -134,24 +134,23 @@ def get_batch_for_sequence_parallel(data_iterator):
global_rank = torch.distributed.get_rank()
local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR)
local_rank = global_rank % local_world_size
seq_length = data_b['text'].size(1)
seq_length = data_b["text"].size(1)
sub_seq_length = seq_length // local_world_size
sub_seq_start = local_rank * sub_seq_length
sub_seq_end = (local_rank + 1) * sub_seq_length
#
# # Unpack.
tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long()
types = data_b['types'][:, sub_seq_start:sub_seq_end].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float()
lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long()
padding_mask = data_b['padding_mask'].long()
tokens = data_b["text"][:, sub_seq_start:sub_seq_end].long()
types = data_b["types"][:, sub_seq_start:sub_seq_end].long()
sentence_order = data_b["is_random"].long()
loss_mask = data_b["loss_mask"][:, sub_seq_start:sub_seq_end].float()
lm_labels = data_b["labels"][:, sub_seq_start:sub_seq_end].long()
padding_mask = data_b["padding_mask"].long()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
class SequenceParallelDataIterator:
def __init__(self, data_iter):
self.data_iter = data_iter

View File

@@ -41,10 +41,19 @@ except:
class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
short_seq_prob, seed, binary_head):
def __init__(
self,
name,
indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
masked_lm_prob,
max_seq_length,
short_seq_prob,
seed,
binary_head,
):
# Params to store.
self.name = name
self.seed = seed
@@ -61,11 +70,12 @@ class BertDataset(Dataset):
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length - 3, # account for added tokens,
self.max_seq_length - 3, # account for added tokens,
short_seq_prob,
self.seed,
self.name,
self.binary_head)
self.binary_head,
)
# Vocab stuff.
tokenizer = get_tokenizer()
@@ -89,7 +99,7 @@ class BertDataset(Dataset):
return build_training_sample(
sample,
seq_length,
self.max_seq_length, # needed for padding
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id,
@@ -98,37 +108,39 @@ class BertDataset(Dataset):
self.pad_id,
self.masked_lm_prob,
np_rng,
self.binary_head)
self.binary_head,
)
def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob,
seed, name, binary_head):
def get_samples_mapping_(
indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head
):
logger = get_dist_logger()
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
raise ValueError("Need to specify either max_num_samples " "or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
indexmap_filename += "_{}_indexmap".format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
indexmap_filename += "_{}ep".format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'
indexmap_filename += "_{}mns".format(max_num_samples)
indexmap_filename += "_{}msl".format(max_seq_length)
indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob)
indexmap_filename += "_{}s".format(seed)
indexmap_filename += ".npy"
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
print(
" > WARNING: could not find index map file {}, building "
"the indices on rank 0 ...".format(indexmap_filename)
)
# Make sure the types match the helpers input types.
assert indexed_dataset.doc_idx.dtype == np.int64
@@ -137,18 +149,27 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0])
logger.info("\n > building samples index mapping for {} ...".format(name), ranks=[0])
# First compile and then import.
samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs,
max_num_samples, max_seq_length, short_seq_prob, seed, verbose,
2 if binary_head else 1)
logger.info('\n > done building samples index maping', ranks=[0])
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
indexed_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
verbose,
2 if binary_head else 1,
)
logger.info("\n > done building samples index maping", ranks=[0])
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0])
logger.info("\n > saved the index mapping in {}".format(indexmap_filename), ranks=[0])
# Make sure all the ranks have built the mapping
logger.info('\n > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(time.time() - start_time),
ranks=[0])
logger.info(
"\n > elapsed time to build and save samples mapping " "(seconds): {:4f}".format(time.time() - start_time),
ranks=[0],
)
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
@@ -156,22 +177,38 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl
torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA))
if gpc.is_initialized(ParallelMode.PIPELINE):
torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE))
assert counts[0].item() == (torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE)))
assert counts[0].item() == (
torch.distributed.get_world_size()
// torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))
)
# Load indexed dataset.
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) +
'\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) +
'\n total number of samples: {}'.format(samples_mapping.shape[0]),
ranks=[0])
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r")
logger.info(
"\n > loading indexed mapping from {}".format(indexmap_filename)
+ "\n loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
+ "\n total number of samples: {}".format(samples_mapping.shape[0]),
ranks=[0],
)
return samples_mapping
def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id,
sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head):
def build_training_sample(
sample,
target_seq_length,
max_seq_length,
vocab_id_list,
vocab_id_to_token_dict,
cls_id,
sep_id,
mask_id,
pad_id,
masked_lm_prob,
np_rng,
binary_head,
):
"""Build training sample.
Arguments:
@@ -215,22 +252,30 @@ def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_li
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels,
_) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id,
mask_id, max_predictions_per_seq, np_rng)
(tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
tokens,
vocab_id_list,
vocab_id_to_token_dict,
masked_lm_prob,
cls_id,
sep_id,
mask_id,
max_predictions_per_seq,
np_rng,
)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length)
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy(
tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length
)
train_sample = {
'text': tokens_np,
'types': tokentypes_np,
'labels': labels_np,
'is_random': int(is_next_random),
'loss_mask': loss_mask_np,
'padding_mask': padding_mask_np,
'truncated': int(truncated)
"text": tokens_np,
"types": tokentypes_np,
"labels": labels_np,
"is_random": int(is_next_random),
"loss_mask": loss_mask_np,
"padding_mask": padding_mask_np,
"truncated": int(truncated),
}
return train_sample

View File

@@ -22,9 +22,7 @@ import torch
class BlendableDataset(torch.utils.data.Dataset):
def __init__(self, datasets, weights):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)
@@ -46,12 +44,16 @@ class BlendableDataset(torch.utils.data.Dataset):
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
from . import helpers
helpers.build_blending_indices(self.dataset_index,
self.dataset_sample_index,
weights, num_datasets, self.size,
torch.distributed.get_rank() == 0)
print('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))
helpers.build_blending_indices(
self.dataset_index,
self.dataset_sample_index,
weights,
num_datasets,
self.size,
torch.distributed.get_rank() == 0,
)
print("> elapsed time for building blendable dataset indices: " "{:.2f} (sec)".format(time.time() - start_time))
def __len__(self):
return self.size

View File

@@ -1,29 +1,34 @@
from .blendable_dataset import BlendableDataset
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
from .bert_dataset import BertDataset
from colossalai.logging import get_dist_logger
DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
from .bert_dataset import BertDataset
from .blendable_dataset import BlendableDataset
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
DSET_TYPE_BERT = "standard_bert"
DSET_TYPE_ICT = "ict"
DSET_TYPE_T5 = "t5"
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
def _build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
dataset_type="standard_bert",
):
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
# Get start and end indices of train/valid/train into doc-idx
# Note that doc-idx is designed to be num-docs + 1 so we can
@@ -34,22 +39,25 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
logger = get_dist_logger()
# Print stats about the splits.
logger.info('\n > dataset split:', ranks=[0])
logger.info("\n > dataset split:", ranks=[0])
def print_split_stats(name, index):
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
logger.info('\n {}:'.format(name) +
'\n document indices in [{}, {}) total of {} documents'.format(
splits[index], splits[index + 1],
splits[index + 1] - splits[index]) +
'\n sentence indices in [{}, {}) total of {} sentences'.format(
start_index, end_index,
end_index - start_index),
ranks=[0])
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
logger.info(
"\n {}:".format(name)
+ "\n document indices in [{}, {}) total of {} documents".format(
splits[index], splits[index + 1], splits[index + 1] - splits[index]
)
+ "\n sentence indices in [{}, {}) total of {} sentences".format(
start_index, end_index, end_index - start_index
),
ranks=[0],
)
print_split_stats("train", 0)
print_split_stats("validation", 1)
print_split_stats("test", 2)
def build_dataset(index, name):
dataset = None
@@ -80,44 +88,53 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
masked_lm_prob=masked_lm_prob,
short_seq_prob=short_seq_prob,
binary_head=binary_head,
**kwargs
**kwargs,
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
train_dataset = build_dataset(0, "train")
valid_dataset = build_dataset(1, "valid")
test_dataset = build_dataset(2, "test")
return (train_dataset, valid_dataset, test_dataset)
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
def build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
dataset_type="standard_bert",
):
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed,
skip_warmup,
binary_head,
dataset_type=dataset_type)
return _build_train_valid_test_datasets(
data_prefix[0],
data_impl,
splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
dataset_type=dataset_type,
)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
@@ -126,10 +143,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
prefixes[i],
data_impl,
splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, binary_head, dataset_type=dataset_type)
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
dataset_type=dataset_type,
)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
@@ -148,5 +173,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)

View File

@@ -14,7 +14,6 @@
# limitations under the License.
"""Dataloaders."""
import random
import torch
@@ -22,61 +21,60 @@ from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0):
def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type="single", num_workers=0):
"""Build dataloader given an input dataset."""
if dataset is None:
return None
# Megatron sampler
if dataloader_type == 'single':
batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
elif dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
if dataloader_type == "single":
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
data_parallel_size=gpc.get_world_size(ParallelMode.DATA),
)
elif dataloader_type == "cyclic":
batch_sampler = MegatronPretrainingRandomSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
data_parallel_size=gpc.get_world_size(ParallelMode.DATA),
)
else:
raise Exception('{} dataloader type is not supported.'.format(dataloader_type))
raise Exception("{} dataloader type is not supported.".format(dataloader_type))
# Torch dataloader.
return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True)
class MegatronPretrainingSampler:
def __init__(self,
total_samples,
consumed_samples,
micro_batch_size,
data_parallel_rank,
data_parallel_size,
drop_last=True):
def __init__(
self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True
):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format(
self.consumed_samples, self.total_samples
)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
assert (
self.data_parallel_rank < data_parallel_size
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
self.data_parallel_rank, data_parallel_size
)
def __len__(self):
return self.total_samples
@@ -103,7 +101,6 @@ class MegatronPretrainingSampler:
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
@@ -111,19 +108,18 @@ class MegatronPretrainingRandomSampler:
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
assert (
self.data_parallel_rank < data_parallel_size
), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
self.data_parallel_rank, data_parallel_size
)
def __len__(self):
return self.total_samples
@@ -135,8 +131,7 @@ class MegatronPretrainingRandomSampler:
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

View File

@@ -18,32 +18,33 @@
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import collections
import math
import time
import collections
from colossalai.logging import get_dist_logger
import numpy as np
from colossalai.logging import get_dist_logger
from .blendable_dataset import BlendableDataset
from .indexed_dataset import make_dataset as make_indexed_dataset
DSET_TYPE_STD = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_STD = "standard_bert"
DSET_TYPE_ICT = "ict"
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
def get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples):
def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples):
# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0]*num_datasets
prefixes = [0]*num_datasets
weights = [0] * num_datasets
prefixes = [0] * num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2*i])
prefixes[i] = (data_prefix[2*i+1]).strip()
weights[i] = float(data_prefix[2 * i])
prefixes[i] = (data_prefix[2 * i + 1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
@@ -57,8 +58,8 @@ def get_datasets_weights_and_num_samples(data_prefix,
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[int(math.ceil(val * weight * 1.005))
for val in train_valid_test_num_samples])
[int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]
)
return prefixes, weights, datasets_train_valid_test_num_samples
@@ -68,11 +69,13 @@ def compile_helper():
is invoked on a single process."""
import os
import subprocess
path = os.path.abspath(os.path.dirname(__file__))
ret = subprocess.run(['make', '-C', path])
ret = subprocess.run(["make", "-C", path])
if ret.returncode != 0:
print("Making C++ dataset helpers module failed, exiting.")
import sys
sys.exit(1)
@@ -82,7 +85,7 @@ def get_a_and_b_segments(sample, np_rng):
# Number of sentences in the sample.
n_sentences = len(sample)
# Make sure we always have two sentences.
assert n_sentences > 1, 'make sure each sample has at least two sentences.'
assert n_sentences > 1, "make sure each sample has at least two sentences."
# First part:
# `a_end` is how many sentences go into the `A`.
@@ -110,7 +113,7 @@ def get_a_and_b_segments(sample, np_rng):
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length."""
#print(len_a, len_b, max_num_tokens)
# print(len_a, len_b, max_num_tokens)
assert len_a > 0
if len_a + len_b <= max_num_tokens:
return False
@@ -155,8 +158,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
return tokens, tokentypes
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])
def is_start_piece(piece):
@@ -168,16 +170,21 @@ def is_start_piece(piece):
return not piece.startswith("##")
def create_masked_lm_predictions(tokens,
vocab_id_list, vocab_id_to_token_dict,
masked_lm_prob,
cls_id, sep_id, mask_id,
max_predictions_per_seq,
np_rng,
max_ngrams=3,
do_whole_word_mask=True,
favor_longer_ngram=False,
do_permutation=False):
def create_masked_lm_predictions(
tokens,
vocab_id_list,
vocab_id_to_token_dict,
masked_lm_prob,
cls_id,
sep_id,
mask_id,
max_predictions_per_seq,
np_rng,
max_ngrams=3,
do_whole_word_mask=True,
favor_longer_ngram=False,
do_permutation=False,
):
"""Creates the predictions for the masked LM objective.
Note: Tokens here are vocab ids and not text tokens."""
@@ -187,7 +194,7 @@ def create_masked_lm_predictions(tokens,
# on-the-fly whole word masking is possible.
token_boundary = [0] * len(tokens)
for (i, token) in enumerate(tokens):
for i, token in enumerate(tokens):
if token == cls_id or token == sep_id:
token_boundary[i] = 1
continue
@@ -197,8 +204,7 @@ def create_masked_lm_predictions(tokens,
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (do_whole_word_mask and len(cand_indexes) >= 1 and
not is_start_piece(vocab_id_to_token_dict[token])):
if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
@@ -211,16 +217,14 @@ def create_masked_lm_predictions(tokens,
masked_lm_labels = []
if masked_lm_prob == 0:
return (output_tokens, masked_lm_positions,
masked_lm_labels, token_boundary)
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob))))
# Note(mingdachen):
# By default, we set the probabilities to favor shorter ngram sequences.
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
pvals = 1. / np.arange(1, max_ngrams + 1)
pvals = 1.0 / np.arange(1, max_ngrams + 1)
pvals /= pvals.sum(keepdims=True)
if favor_longer_ngram:
@@ -230,7 +234,7 @@ def create_masked_lm_predictions(tokens,
for idx in range(len(cand_indexes)):
ngram_index = []
for n in ngrams:
ngram_index.append(cand_indexes[idx:idx + n])
ngram_index.append(cand_indexes[idx : idx + n])
ngram_indexes.append(ngram_index)
np_rng.shuffle(ngram_indexes)
@@ -249,9 +253,10 @@ def create_masked_lm_predictions(tokens,
if index in covered_indexes:
continue
n = np_rng.choice(ngrams[:len(cand_index_set)],
p=pvals[:len(cand_index_set)] /
pvals[:len(cand_index_set)].sum(keepdims=True))
n = np_rng.choice(
ngrams[: len(cand_index_set)],
p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
)
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# Note(mingdachen):
@@ -309,9 +314,10 @@ def create_masked_lm_predictions(tokens,
if index in covered_indexes or index in select_indexes:
continue
n = np.random.choice(ngrams[:len(cand_index_set)],
p=pvals[:len(cand_index_set)] /
pvals[:len(cand_index_set)].sum(keepdims=True))
n = np.random.choice(
ngrams[: len(cand_index_set)],
p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True),
)
index_set = sum(cand_index_set[n - 1], [])
n -= 1
@@ -353,8 +359,7 @@ def create_masked_lm_predictions(tokens,
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length):
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length):
"""Pad sequences and convert them to numpy."""
# Some checks.
@@ -370,8 +375,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64)
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64)
# Lables and loss mask.
labels = [-1] * max_seq_length
@@ -386,26 +390,36 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
def build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
dataset_type="standard_bert",
):
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed,
skip_warmup,
binary_head,
dataset_type=dataset_type)
return _build_train_valid_test_datasets(
data_prefix[0],
data_impl,
splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
dataset_type=dataset_type,
)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
@@ -414,10 +428,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
prefixes[i],
data_impl,
splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, binary_head, dataset_type=dataset_type)
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
dataset_type=dataset_type,
)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
@@ -436,31 +458,33 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
def _build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob,
short_seq_prob,
seed,
skip_warmup,
binary_head,
dataset_type="standard_bert",
):
logger = get_dist_logger()
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
title_dataset = get_indexed_dataset_(args.titles_data_path,
data_impl,
skip_warmup)
title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup)
# Get start and end indices of train/valid/train into doc-idx
# Note that doc-idx is designed to be num-docs + 1 so we can
@@ -469,27 +493,29 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
logger.info('\n > dataset split:')
logger.info("\n > dataset split:")
def print_split_stats(name, index):
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
logger.info('\n {}:'.format(name) +
'\n document indices in [{}, {}) total of {} documents'.format(
splits[index],
splits[index + 1],
splits[index + 1] - splits[index]) +
'\n sentence indices in [{}, {}) total of {} sentences'.format(
start_index,
end_index,
end_index - start_index),
ranks=[0])
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
logger.info(
"\n {}:".format(name)
+ "\n document indices in [{}, {}) total of {} documents".format(
splits[index], splits[index + 1], splits[index + 1] - splits[index]
)
+ "\n sentence indices in [{}, {}) total of {} sentences".format(
start_index, end_index, end_index - start_index
),
ranks=[0],
)
print_split_stats("train", 0)
print_split_stats("validation", 1)
print_split_stats("test", 2)
def build_dataset(index, name):
from .bert_dataset import BertDataset
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
@@ -508,7 +534,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
seed=seed,
binary_head=binary_head
binary_head=binary_head,
)
if dataset_type == DSET_TYPE_ICT:
@@ -518,27 +544,26 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
title_dataset=title_dataset,
query_in_block_prob=args.query_in_block_prob,
use_one_sent_docs=args.use_one_sent_docs,
**kwargs
**kwargs,
)
else:
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
short_seq_prob=short_seq_prob,
**kwargs
**kwargs,
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
train_dataset = build_dataset(0, "train")
valid_dataset = build_dataset(1, "valid")
test_dataset = build_dataset(2, "test")
return (train_dataset, valid_dataset, test_dataset)
@@ -546,44 +571,41 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
logger = get_dist_logger()
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
logger.info('\n > building dataset index ...', ranks=[0])
logger.info('\n > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time), ranks=[0])
logger.info('\n > indexed dataset stats:' +
'\n number of documents: {}'.format(
indexed_dataset.doc_idx.shape[0] - 1) +
'\n number of sentences: {}'.format(
indexed_dataset.sizes.shape[0]),
ranks=[0]
)
logger.info("\n > building dataset index ...", ranks=[0])
logger.info(
"\n > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time), ranks=[0]
)
logger.info(
"\n > indexed dataset stats:"
+ "\n number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1)
+ "\n number of sentences: {}".format(indexed_dataset.sizes.shape[0]),
ranks=[0],
)
return indexed_dataset
def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
"""Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
if splits_string.find(",") != -1:
splits = [float(s) for s in splits_string.split(",")]
elif splits_string.find("/") != -1:
splits = [float(s) for s in splits_string.split("/")]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.)
splits.append(0.0)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
int(round(split * float(size))))
splits_index.append(splits_index[index] + int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff

File diff suppressed because it is too large Load Diff

View File

@@ -2,12 +2,11 @@ import itertools
import random
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import get_args
from megatron import get_args, get_tokenizer
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
from torch.utils.data import Dataset
def make_attention_mask(source_block, target_block):
"""
@@ -20,16 +19,17 @@ def make_attention_mask(source_block, target_block):
# (source_length, target_length)
return mask
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
rather than for training, since it is only built with a single epoch sample mapping.
"""
args = get_args()
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
block_dataset = get_indexed_dataset_(args.data_path, "mmap", True)
titles_dataset = get_indexed_dataset_(args.titles_data_path, "mmap", True)
kwargs = dict(
name='full',
name="full",
block_dataset=block_dataset,
title_dataset=titles_dataset,
data_prefix=args.data_path,
@@ -39,7 +39,7 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
seed=1,
query_in_block_prob=query_in_block_prob,
use_titles=use_titles,
use_one_sent_docs=args.use_one_sent_docs
use_one_sent_docs=args.use_one_sent_docs,
)
dataset = ICTDataset(**kwargs)
return dataset
@@ -47,9 +47,22 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
def __init__(
self,
name,
block_dataset,
title_dataset,
data_prefix,
num_epochs,
max_num_samples,
max_seq_length,
query_in_block_prob,
seed,
use_titles=True,
use_one_sent_docs=False,
binary_head=False,
):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
@@ -61,8 +74,16 @@ class ICTDataset(Dataset):
self.use_one_sent_docs = use_one_sent_docs
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
block_dataset,
title_dataset,
data_prefix,
num_epochs,
max_num_samples,
max_seq_length,
seed,
name,
use_one_sent_docs,
)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
@@ -99,8 +120,8 @@ class ICTDataset(Dataset):
# still need to truncate because blocks are concluded when
# the sentence lengths have exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query = query[: self.max_seq_length - 2]
block = list(itertools.chain(*block))[: self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
@@ -111,13 +132,13 @@ class ICTDataset(Dataset):
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
'query_mask': query_mask,
'query_pad_mask': query_pad_mask,
'context_tokens': context_tokens,
'context_mask': context_mask,
'context_pad_mask': context_pad_mask,
'block_data': block_data,
"query_tokens": query_tokens,
"query_mask": query_mask,
"query_pad_mask": query_pad_mask,
"context_tokens": context_tokens,
"context_mask": context_mask,
"context_pad_mask": context_pad_mask,
"block_data": block_data,
}
return sample
@@ -127,7 +148,7 @@ class ICTDataset(Dataset):
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
title = self.title_dataset[int(doc_idx)]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block = list(itertools.chain(*block))[: self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask

View File

@@ -27,17 +27,17 @@ def __best_fitting_dtype(vocab_size=None):
def get_available_dataset_impl():
return ['lazy', 'cached', 'mmap']
return ["lazy", "cached", "mmap"]
def infer_dataset_impl(path):
if IndexedDataset.exists(path):
with open(index_file_path(path), 'rb') as f:
with open(index_file_path(path), "rb") as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
return 'cached'
return "cached"
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return 'mmap'
return "mmap"
else:
return None
else:
@@ -47,7 +47,7 @@ def infer_dataset_impl(path):
def make_builder(out_file, impl, vocab_size=None):
if impl == 'mmap':
if impl == "mmap":
return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
else:
return IndexedDatasetBuilder(out_file)
@@ -58,20 +58,20 @@ def make_dataset(path, impl, skip_warmup=False):
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
if impl == 'infer':
if impl == "infer":
impl = infer_dataset_impl(path)
if impl == 'lazy' and IndexedDataset.exists(path):
if impl == "lazy" and IndexedDataset.exists(path):
return IndexedDataset(path)
elif impl == 'cached' and IndexedDataset.exists(path):
elif impl == "cached" and IndexedDataset.exists(path):
return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
elif impl == "mmap" and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None
def dataset_exists(path, impl):
if impl == 'mmap':
if impl == "mmap":
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
@@ -98,11 +98,11 @@ def code(dtype):
def index_file_path(prefix_path):
return prefix_path + '.idx'
return prefix_path + ".idx"
def data_file_path(prefix_path):
return prefix_path + '.bin'
return prefix_path + ".bin"
def create_doc_idx(sizes):
@@ -115,7 +115,8 @@ def create_doc_idx(sizes):
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset"""
_HDR_MAGIC = b'TNTIDX\x00\x00'
_HDR_MAGIC = b"TNTIDX\x00\x00"
def __init__(self, path):
super().__init__()
@@ -124,27 +125,28 @@ class IndexedDataset(torch.utils.data.Dataset):
self.read_index(path)
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
with open(index_file_path(path), "rb") as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.')
assert magic == self._HDR_MAGIC, (
"Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly."
)
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
assert struct.unpack("<Q", version) == (1,)
code, self.element_size = struct.unpack("<QQ", f.read(16))
self.dtype = dtypes[code]
self._len, self.s = struct.unpack('<QQ', f.read(16))
self.doc_count = struct.unpack('<Q', f.read(8))
self._len, self.s = struct.unpack("<QQ", f.read(16))
self.doc_count = struct.unpack("<Q", f.read(8))
self.dim_offsets = read_longs(f, self._len + 1)
self.data_offsets = read_longs(f, self._len + 1)
self.sizes = read_longs(f, self.s)
self.doc_idx = read_longs(f, self.doc_count)
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
self.data_file = open(data_file_path(path), "rb", buffering=0)
def check_index(self, i):
if i < 0 or i >= self._len:
raise IndexError('index out of range')
raise IndexError("index out of range")
def __del__(self):
if self.data_file:
@@ -157,7 +159,7 @@ class IndexedDataset(torch.utils.data.Dataset):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
@@ -166,7 +168,7 @@ class IndexedDataset(torch.utils.data.Dataset):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
size = sum(sizes)
a = np.empty(size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[start] * self.element_size)
@@ -186,15 +188,14 @@ class IndexedDataset(torch.utils.data.Dataset):
@staticmethod
def exists(path):
return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path):
super().__init__(path)
self.cache = None
@@ -219,7 +220,7 @@ class IndexedCachedDataset(IndexedDataset):
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx:ptx + size]
a = self.cache[ptx : ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
@@ -233,10 +234,10 @@ class IndexedCachedDataset(IndexedDataset):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx:ptx + a.size])
np.copyto(a, self.cache[ptx : ptx + a.size])
return a
elif isinstance(idx, slice):
# Hack just to make this work, can optimizer later if necessary
@@ -250,7 +251,7 @@ class IndexedDatasetBuilder(object):
element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.out_file = open(out_file, "wb")
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
@@ -280,7 +281,7 @@ class IndexedDatasetBuilder(object):
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
with open(data_file_path(another_file), 'rb') as f:
with open(data_file_path(another_file), "rb") as f:
while True:
data = f.read(1024)
if data:
@@ -290,12 +291,12 @@ class IndexedDatasetBuilder(object):
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
index.write(struct.pack('<Q', len(self.doc_idx)))
index = open(index_file, "wb")
index.write(b"TNTIDX\x00\x00")
index.write(struct.pack("<Q", 1))
index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
index.write(struct.pack("<Q", len(self.doc_idx)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
@@ -304,27 +305,24 @@ class IndexedDatasetBuilder(object):
def _warmup_mmap_file(path):
with open(path, 'rb') as stream:
with open(path, "rb") as stream:
while stream.read(100 * 1024 * 1024):
pass
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
_HDR_MAGIC = b"MMIDIDX\x00\x00"
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, 'wb')
self._file = open(path, "wb")
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack('<Q', 1))
self._file.write(struct.pack('<B', code(dtype)))
self._file.write(struct.pack("<Q", 1))
self._file.write(struct.pack("<B", code(dtype)))
return self
@@ -343,19 +341,19 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
self._file.write(struct.pack('<Q', len(sizes)))
self._file.write(struct.pack('<Q', len(doc_idx)))
self._file.write(struct.pack("<Q", len(sizes)))
self._file.write(struct.pack("<Q", len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order='C'))
self._file.write(sizes.tobytes(order="C"))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order='C'))
self._file.write(pointers.tobytes(order="C"))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order='C'))
self._file.write(doc_idx.tobytes(order="C"))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
@@ -363,39 +361,41 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, 'rb') as stream:
with open(path, "rb") as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, ('Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.')
version = struct.unpack('<Q', stream.read(8))
assert self._HDR_MAGIC == magic_test, (
"Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly."
)
version = struct.unpack("<Q", stream.read(8))
assert (1,) == version
dtype_code, = struct.unpack('<B', stream.read(1))
(dtype_code,) = struct.unpack("<B", stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack('<Q', stream.read(8))[0]
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
self._len = struct.unpack("<Q", stream.read(8))[0]
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print(" reading sizes...")
self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
print(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer,
dtype=np.int64,
count=self._len,
offset=offset + self._sizes.nbytes)
self._pointers = np.frombuffer(
self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes
)
print(" reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
self._doc_idx = np.frombuffer(
self._bin_buffer,
dtype=np.int64,
count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
@@ -443,7 +443,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
print(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode="r", order="C")
print(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
@@ -474,7 +474,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
return sents
def get(self, idx, offset=0, length=None):
""" Retrieves a single item from the dataset with the option to only
"""Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
@@ -506,20 +506,19 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
@staticmethod
def exists(path):
return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)))
return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb')
self._data_file = open(out_file, "wb")
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]
def add_item(self, tensor):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._data_file.write(np_array.tobytes(order="C"))
self._sizes.append(np_array.size)
def end_document(self):
@@ -534,7 +533,7 @@ class MMapIndexedDatasetBuilder(object):
self._sizes.append(size)
# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
with open(data_file_path(another_file), "rb") as f:
shutil.copyfileobj(f, self._data_file)
def finalize(self, index_file):

View File

@@ -2,13 +2,12 @@
# put some code used during development and manual testing of
# indexed_dataset.
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys
import torch
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))
@@ -42,7 +41,7 @@ def test_indexed_dataset(args):
def test_indexed_dataset_get(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
build_tokenizer(args)
size = ds.sizes[0]
print(f"size: {size}")
full = ds.get(0)
@@ -61,6 +60,7 @@ def test_indexed_dataset_get(args):
print(part)
# print(tokenizer.detokenize(part.data.tolist()))
# def test_albert_dataset(args):
# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
@@ -81,34 +81,27 @@ def test_indexed_dataset_get(args):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='prefix to data files')
parser.add_argument('--dataset-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'])
parser.add_argument('--count', type=int, default=10,
help='Number of samples/documents to print')
parser.add_argument("--data", type=str, help="prefix to data files")
parser.add_argument("--dataset-impl", type=str, default="infer", choices=["lazy", "cached", "mmap", "infer"])
parser.add_argument("--count", type=int, default=10, help="Number of samples/documents to print")
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase',
'GPT2BPETokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group = parser.add_argument_group(title="tokenizer")
group.add_argument(
"--tokenizer-type",
type=str,
required=True,
choices=["BertWordPieceLowerCase", "GPT2BPETokenizer"],
help="What type of tokenizer to use.",
)
group.add_argument("--vocab-file", type=str, default=None, help="Path to the vocab file")
group.add_argument("--merge-file", type=str, default=None, help="Path to the BPE merge file (if necessary).")
parser.add_argument('--epochs', type=int, default=5,
help='Number of epochs to plan for')
parser.add_argument('--max-num-samples', type=int, default=None,
help='Maximum number of samples to plan for')
parser.add_argument('--masked-lm-prob', type=float, default=0.15,
help='probability of masking tokens')
parser.add_argument('--seq-length', type=int, default=512,
help='maximum sequence length')
parser.add_argument('--short-seq-prob', type=float, default=0.1,
help='probability of creating a short sequence')
parser.add_argument('--seed', type=int, default=1234,
help='random seed')
parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to plan for")
parser.add_argument("--max-num-samples", type=int, default=None, help="Maximum number of samples to plan for")
parser.add_argument("--masked-lm-prob", type=float, default=0.15, help="probability of masking tokens")
parser.add_argument("--seq-length", type=int, default=512, help="maximum sequence length")
parser.add_argument("--short-seq-prob", type=float, default=0.1, help="probability of creating a short sequence")
parser.add_argument("--seed", type=int, default=1234, help="random seed")
args = parser.parse_args()
args.rank = 0
args.make_vocab_size_divisible_by = 128
@@ -117,7 +110,7 @@ def main():
if args.dataset_impl == "infer":
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
# test_albert_dataset(args)
# test_albert_dataset(args)
test_indexed_dataset_get(args)

View File

@@ -1,8 +1,7 @@
import torch
class DummyDataloader():
class DummyDataloader:
def __init__(self, batch_size, vocab_size, seq_length):
self.batch_size = batch_size
self.vocab_size = vocab_size
@@ -10,30 +9,44 @@ class DummyDataloader():
self.step = 0
def generate(self):
tokens = torch.randint(low=0, high=self.vocab_size, size=(
self.batch_size,
self.seq_length,
))
types = torch.randint(low=0, high=3, size=(
self.batch_size,
self.seq_length,
))
tokens = torch.randint(
low=0,
high=self.vocab_size,
size=(
self.batch_size,
self.seq_length,
),
)
types = torch.randint(
low=0,
high=3,
size=(
self.batch_size,
self.seq_length,
),
)
sentence_order = torch.randint(low=0, high=2, size=(self.batch_size,))
loss_mask = torch.randint(low=0, high=2, size=(
self.batch_size,
self.seq_length,
))
loss_mask = torch.randint(
low=0,
high=2,
size=(
self.batch_size,
self.seq_length,
),
)
lm_labels = torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.seq_length))
padding_mask = torch.randint(low=0, high=2, size=(self.batch_size, self.seq_length))
return dict(text=tokens,
types=types,
is_random=sentence_order,
loss_mask=loss_mask,
labels=lm_labels,
padding_mask=padding_mask)
return dict(
text=tokens,
types=types,
is_random=sentence_order,
loss_mask=loss_mask,
labels=lm_labels,
padding_mask=padding_mask,
)
def __iter__(self):
return self
def __next__(self):
return self.generate()
return self.generate()

View File

@@ -16,7 +16,6 @@
from .tokenizer import build_tokenizer
_TOKENIZER = None
_PADDED_VOCAB_SIZE = -1

View File

@@ -15,13 +15,12 @@
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
import collections
import re
import unicodedata
import six
@@ -43,14 +42,13 @@ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
"uncased_L-24_H-1024_A-16",
"uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12",
"chinese_L-12_H-768_A-12",
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
cased_models = ["cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", "multi_cased_L-12_H-768_A-12"]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
@@ -71,8 +69,8 @@ def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
"just comment out this check." % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag)
)
def convert_to_unicode(text):
@@ -183,27 +181,27 @@ class FullTokenizer(object):
@staticmethod
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
""" Converts a sequence of tokens (string) in a single string. """
"""Converts a sequence of tokens (string) in a single string."""
def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts
"""Clean up a list of simple English tokenization artifacts
like spaces before punctuations and abbreviated forms.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
text = ' '.join(tokens).replace(' ##', '').strip()
text = " ".join(tokens).replace(" ##", "").strip()
if clean_up_tokenization_spaces:
clean_text = clean_up_tokenization(text)
return clean_text
@@ -303,14 +301,16 @@ class BasicTokenizer(object):
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
@@ -320,7 +320,7 @@ class BasicTokenizer(object):
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
if cp == 0 or cp == 0xFFFD or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
@@ -422,8 +422,7 @@ def _is_punctuation(char):
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):

View File

@@ -25,16 +25,15 @@ from .bert_tokenization import FullTokenizer as FullBertTokenizer
def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
"""Initialize tokenizer."""
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
print('> building {} tokenizer ...'.format(tokenizer_type), flush=True)
print("> building {} tokenizer ...".format(tokenizer_type), flush=True)
# Select and instantiate the tokenizer.
if tokenizer_type == 'BertWordPieceLowerCase':
if tokenizer_type == "BertWordPieceLowerCase":
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids)
elif tokenizer_type == 'BertWordPieceCase':
elif tokenizer_type == "BertWordPieceCase":
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(tokenizer_type))
raise NotImplementedError("{} tokenizer is not " "implemented.".format(tokenizer_type))
# Add vocab size.
padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size)
@@ -55,9 +54,11 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128):
while (after % multiple) != 0:
after += 1
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
'(new size: {})'.format(orig_vocab_size, after - orig_vocab_size, after),
flush=True)
print(
" > padded vocab (size: {}) with {} dummy tokens "
"(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after),
flush=True,
)
return after
@@ -77,46 +78,38 @@ class AbstractTokenizer(ABC):
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass
@abstractmethod
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name))
@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name))
@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name))
@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name))
@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name))
class _BertWordPieceTokenizer(AbstractTokenizer):
@@ -124,24 +117,24 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
if lower_case:
name = 'BERT Lower Case'
name = "BERT Lower Case"
else:
name = 'BERT Upper Case'
name = "BERT Upper Case"
super().__init__(name)
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]']
self.cls_id = self.tokenizer.vocab["[CLS]"]
self.sep_id = self.tokenizer.vocab["[SEP]"]
self.pad_id = self.tokenizer.vocab["[PAD]"]
self.mask_id = self.tokenizer.vocab["[MASK]"]
self._additional_special_tokens = []
# (dsachan) Add BOS and EOS tokens
SPECIAL_TOKENS = {'eos_token': '[EOS]', 'bos_token': '[BOS]'}
self._bos_token = '[BOS]'
SPECIAL_TOKENS = {"eos_token": "[EOS]", "bos_token": "[BOS]"}
self._bos_token = "[BOS]"
self.add_token(self._bos_token)
self._bos_token_id = self.vocab.get(self._bos_token)
self._eos_token = '[EOS]'
self._eos_token = "[EOS]"
self.add_token(self._eos_token)
self._eos_token_id = self.vocab.get(self._eos_token)
@@ -185,7 +178,7 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
exclude_list = ["[PAD]", "[CLS]"]
non_pads = [t for t in tokens if t not in exclude_list]
result = ""
@@ -215,32 +208,32 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
@property
def bos_token(self):
""" Beginning of sentence token id """
"""Beginning of sentence token id"""
return self._bos_token
@property
def eos_token(self):
""" End of sentence token id """
"""End of sentence token id"""
return self._eos_token
@property
def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings)."""
"""All the additional special tokens you may want to use (list of strings)."""
return self._additional_special_tokens
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary."""
"""Id of the beginning of sentence token in the vocabulary."""
return self._bos_token_id
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary."""
"""Id of the end of sentence token in the vocabulary."""
return self._eos_token_id
@property
def additional_special_tokens_ids(self):
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
"""Ids of all the additional special tokens in the vocabulary (list of integers)."""
return [self.vocab.get(token) for token in self._additional_special_tokens]
@additional_special_tokens.setter

View File

@@ -1,17 +1,12 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger
from .cross_entropy import vocab_cross_entropy
class BertLoss(nn.Module):
def forward(self, lm_loss, sop_logits, loss_mask, sentence_order):
lm_loss_ = lm_loss.float()
loss_mask = loss_mask.float()

View File

@@ -1,11 +1,8 @@
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from colossalai.legacy.context.parallel_mode import ParallelMode
class _VocabCrossEntropy(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, vocab_parallel_logits, target):
@@ -59,7 +56,7 @@ class _VocabCrossEntropy(torch.autograd.Function):
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))

View File

@@ -1,11 +1,9 @@
import torch
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(
numerator, denominator)
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
@@ -15,8 +13,7 @@ def divide(numerator, denominator):
return numerator // denominator
def split_tensor_along_last_dim(tensor, num_partitions,
contiguous_split_chunks=False):
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
@@ -38,12 +35,11 @@ def split_tensor_along_last_dim(tensor, num_partitions,
class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank, world_size):
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@@ -51,5 +47,4 @@ class VocabUtility:
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)

View File

@@ -21,16 +21,17 @@ import math
class AnnealingLR(object):
"""Anneals the learning rate."""
def __init__(self,
optimizer,
max_lr,
min_lr,
warmup_steps,
decay_steps,
decay_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
def __init__(
self,
optimizer,
max_lr,
min_lr,
warmup_steps,
decay_steps,
decay_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False,
):
# Class values.
self.optimizer = optimizer
@@ -50,23 +51,21 @@ class AnnealingLR(object):
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
'use-checkpoint are set.'
assert not self.use_checkpoint_lr_scheduler, "both override and " "use-checkpoint are set."
# Set the learning rate
self.step(0)
def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
# Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
return self.max_lr * float(self.num_steps) / \
float(self.warmup_steps)
return self.max_lr * float(self.num_steps) / float(self.warmup_steps)
# If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant':
if self.decay_style == "constant":
return self.max_lr
# For any steps larger than `self.decay_steps`, use `self.min_lr`.
@@ -81,13 +80,12 @@ class AnnealingLR(object):
assert decay_ratio <= 1.0
delta_lr = self.max_lr - self.min_lr
if self.decay_style == 'linear':
coeff = (1.0 - decay_ratio)
elif self.decay_style == 'cosine':
if self.decay_style == "linear":
coeff = 1.0 - decay_ratio
elif self.decay_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else:
raise Exception('{} decay style is not supported.'.format(
self.decay_style))
raise Exception("{} decay style is not supported.".format(self.decay_style))
return self.min_lr + coeff * delta_lr
@@ -96,16 +94,16 @@ class AnnealingLR(object):
self.num_steps += increment
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
group["lr"] = new_lr
def state_dict(self):
state_dict = {
'max_lr': self.max_lr,
'warmup_steps': self.warmup_steps,
'num_steps': self.num_steps,
'decay_style': self.decay_style,
'decay_steps': self.decay_steps,
'min_lr': self.min_lr
"max_lr": self.max_lr,
"warmup_steps": self.warmup_steps,
"num_steps": self.num_steps,
"decay_style": self.decay_style,
"decay_steps": self.decay_steps,
"min_lr": self.min_lr,
}
return state_dict
@@ -116,43 +114,35 @@ class AnnealingLR(object):
return cls_value
if not self.use_checkpoint_lr_scheduler:
assert cls_value == sd_value, \
f'AnnealingLR: class input value {cls_value} and checkpoint' \
f'value {sd_value} for {name} do not match'
assert cls_value == sd_value, (
f"AnnealingLR: class input value {cls_value} and checkpoint" f"value {sd_value} for {name} do not match"
)
return sd_value
def load_state_dict(self, sd):
if 'start_lr' in sd:
max_lr_ = sd['start_lr']
if "start_lr" in sd:
max_lr_ = sd["start_lr"]
else:
max_lr_ = sd['max_lr']
self.max_lr = self._check_and_set(self.max_lr, max_lr_,
'learning rate')
max_lr_ = sd["max_lr"]
self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")
self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
'minimum learning rate')
self.min_lr = self._check_and_set(self.min_lr, sd["min_lr"], "minimum learning rate")
if 'warmup_iter' in sd:
warmup_steps_ = sd['warmup_iter']
if "warmup_iter" in sd:
warmup_steps_ = sd["warmup_iter"]
else:
warmup_steps_ = sd['warmup_steps']
self.warmup_steps = self._check_and_set(self.warmup_steps,
warmup_steps_,
'warmup iterations')
warmup_steps_ = sd["warmup_steps"]
self.warmup_steps = self._check_and_set(self.warmup_steps, warmup_steps_, "warmup iterations")
if 'end_iter' in sd:
decay_steps_ = sd['end_iter']
if "end_iter" in sd:
decay_steps_ = sd["end_iter"]
else:
decay_steps_ = sd['decay_steps']
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
'total number of iterations')
self.decay_style = self._check_and_set(self.decay_style,
sd['decay_style'],
'decay style')
decay_steps_ = sd["decay_steps"]
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, "total number of iterations")
self.decay_style = self._check_and_set(self.decay_style, sd["decay_style"], "decay style")
if 'num_iters' in sd:
num_steps = sd['num_iters']
if "num_iters" in sd:
num_steps = sd["num_iters"]
else:
num_steps = sd['num_steps']
num_steps = sd["num_steps"]
self.step(increment=num_steps)

View File

@@ -16,7 +16,6 @@ from .layers.init_method import init_normal, output_init_normal
class BertForPretrain(nn.Module):
def __init__(
self,
vocab_size,
@@ -34,7 +33,9 @@ class BertForPretrain(nn.Module):
):
super().__init__()
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
assert (
max_sequence_length % self.seq_parallel_size == 0
), "sequence length is not divisible by the sequence parallel size"
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
self.init_std = init_std
self.num_layers = num_layers
@@ -43,28 +44,32 @@ class BertForPretrain(nn.Module):
num_tokentypes = 0
self.preprocessor = PreProcessor(self.sub_seq_length)
self.embedding = Embedding(hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes)
self.embedding = Embedding(
hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes,
)
self.bert_layers = nn.ModuleList()
for i in range(num_layers):
bert_layer = BertLayer(layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16)
bert_layer = BertLayer(
layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16,
)
self.bert_layers.append(bert_layer)
self.layer_norm = LayerNorm(hidden_size)
self.head = BertDualHead(hidden_size,
self.embedding.word_embedding_weight.size(0),
add_binary_head=add_binary_head)
self.head = BertDualHead(
hidden_size, self.embedding.word_embedding_weight.size(0), add_binary_head=add_binary_head
)
self.reset_parameters()
def _init_normal(self, tensor):
@@ -122,27 +127,30 @@ class BertForPretrain(nn.Module):
class PipelineBertForPretrain(nn.Module):
def __init__(self,
vocab_size,
hidden_size,
max_sequence_length,
num_attention_heads,
num_layers,
add_binary_head,
is_naive_fp16,
num_tokentypes=2,
dropout_prob=0.1,
mlp_ratio=4,
init_std=0.02,
convert_fp16_to_fp32_in_softmax=False,
first_stage=True,
last_stage=True,
start_idx=None,
end_idx=None):
def __init__(
self,
vocab_size,
hidden_size,
max_sequence_length,
num_attention_heads,
num_layers,
add_binary_head,
is_naive_fp16,
num_tokentypes=2,
dropout_prob=0.1,
mlp_ratio=4,
init_std=0.02,
convert_fp16_to_fp32_in_softmax=False,
first_stage=True,
last_stage=True,
start_idx=None,
end_idx=None,
):
super().__init__()
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
assert (
max_sequence_length % self.seq_parallel_size == 0
), "sequence length is not divisible by the sequence parallel size"
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
self.init_std = init_std
self.num_layers = num_layers
@@ -156,11 +164,13 @@ class PipelineBertForPretrain(nn.Module):
self.preprocessor = PreProcessor(self.sub_seq_length)
if self.first_stage:
self.embedding = Embedding(hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes)
self.embedding = Embedding(
hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes,
)
# transformer layers
self.bert_layers = nn.ModuleList()
@@ -170,14 +180,16 @@ class PipelineBertForPretrain(nn.Module):
end_idx = num_layers
for i in range(start_idx, end_idx):
bert_layer = BertLayer(layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16)
bert_layer = BertLayer(
layer_number=i + 1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16,
)
self.bert_layers.append(bert_layer)
if self.last_stage:
@@ -256,7 +268,7 @@ def _filter_kwargs(func, kwargs):
return {k: v for k, v in kwargs.items() if k in sig.parameters}
def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
def build_pipeline_bert(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
logger = get_dist_logger()
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -265,12 +277,12 @@ def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **k
parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
models = []
for start, end in parts:
kwargs['num_layers'] = num_layers
kwargs['start_idx'] = start
kwargs['end_idx'] = end
kwargs['first_stage'] = start == 0
kwargs['last_stage'] = end == num_layers
logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
kwargs["num_layers"] = num_layers
kwargs["start_idx"] = start
kwargs["end_idx"] = end
kwargs["first_stage"] = start == 0
kwargs["last_stage"] = end == num_layers
logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers")
chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device)
if start == 0:
wrapper.register_module(chunk.embedding.word_embeddings)

View File

@@ -1,4 +1,4 @@
from .embedding import VocabEmbedding, Embedding
from .bert_layer import BertLayer
from .embedding import Embedding, VocabEmbedding
from .head import BertDualHead
from .preprocess import PreProcessor

View File

@@ -20,18 +20,20 @@ class BertLayer(nn.Module):
output of the same size.
"""
def __init__(self,
layer_number,
hidden_size,
num_attention_heads,
attention_dropout,
mlp_ratio,
hidden_dropout,
is_naive_fp16,
apply_residual_connection_post_layernorm=False,
fp32_residual_connection=False,
bias_dropout_fusion: bool = True,
convert_fp16_to_fp32_in_softmax: bool = False):
def __init__(
self,
layer_number,
hidden_size,
num_attention_heads,
attention_dropout,
mlp_ratio,
hidden_dropout,
is_naive_fp16,
apply_residual_connection_post_layernorm=False,
fp32_residual_connection=False,
bias_dropout_fusion: bool = True,
convert_fp16_to_fp32_in_softmax: bool = False,
):
super().__init__()
self.layer_number = layer_number
@@ -50,7 +52,8 @@ class BertLayer(nn.Module):
layer_number=layer_number,
apply_query_key_layer_scaling=True,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
fp16=is_naive_fp16)
fp16=is_naive_fp16,
)
self.hidden_dropout = hidden_dropout
self.bias_dropout_fusion = bias_dropout_fusion
@@ -90,8 +93,9 @@ class BertLayer(nn.Module):
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(attention_output, attention_bias.expand_as(residual), residual,
self.hidden_dropout)
layernorm_input = bias_dropout_add_func(
attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout
)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)

View File

@@ -1,5 +1,6 @@
import torch
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
@@ -10,4 +11,5 @@ def bias_dropout_add(x, bias, residual, prob, training):
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
return _bias_dropout_add

View File

@@ -5,7 +5,6 @@ import torch.nn.init as init
class VocabEmbedding(torch.nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super(VocabEmbedding, self).__init__()
# Keep the input dimensions.
@@ -13,26 +12,29 @@ class VocabEmbedding(torch.nn.Module):
self.embedding_dim = embedding_dim
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
# Allocate weights and initialize.
self.weight = nn.Parameter(torch.empty(
self.num_embeddings, self.embedding_dim))
self.weight = nn.Parameter(torch.empty(self.num_embeddings, self.embedding_dim))
init.xavier_uniform_(self.weight)
def forward(self, hidden_state):
output = F.embedding(hidden_state, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
output = F.embedding(
hidden_state,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
return output
def __repr__(self):
return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \
f'embedding_dim={self.embedding_dim})'
return f"VocabEmbedding(num_embeddings={self.num_embeddings}, " f"embedding_dim={self.embedding_dim})"
class Embedding(nn.Module):
@@ -48,12 +50,7 @@ class Embedding(nn.Module):
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes):
def __init__(self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes):
super(Embedding, self).__init__()
self.hidden_size = hidden_size
@@ -62,16 +59,14 @@ class Embedding(nn.Module):
self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size)
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
else:
self.tokentype_embeddings = None

View File

@@ -3,12 +3,10 @@ import torch.nn as nn
import torch.nn.functional as F
from loss_func.cross_entropy import vocab_cross_entropy
import colossalai
from colossalai.kernel import LayerNorm
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from .embedding import VocabEmbedding
from .linear import Linear
from .pooler import Pooler
@@ -26,7 +24,6 @@ class BertLMHead(nn.Module):
vocab_size,
hidden_size,
):
super(BertLMHead, self).__init__()
self.bias = torch.nn.Parameter(torch.zeros(vocab_size))
@@ -46,7 +43,6 @@ class BertLMHead(nn.Module):
class BertBinaryHead(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.pooler = Pooler(hidden_size)
@@ -62,7 +58,6 @@ class BertBinaryHead(nn.Module):
class BertDualHead(nn.Module):
def __init__(self, hidden_size, vocab_size, add_binary_head):
super().__init__()
self.lm_head = BertLMHead(vocab_size, hidden_size)

View File

@@ -1,6 +1,8 @@
import torch
import math
import torch
def init_normal(tensor, sigma):
"""Init method based on N(0, sigma)."""
torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

View File

@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import Parameter
class Linear(nn.Module):
@@ -24,11 +24,7 @@ class Linear(nn.Module):
adding bias but instead return it.
"""
def __init__(self,
input_size,
output_size,
bias=True,
skip_bias_add=False):
def __init__(self, input_size, output_size, bias=True, skip_bias_add=False):
super(Linear, self).__init__()
# Keep input parameters
@@ -36,9 +32,12 @@ class Linear(nn.Module):
self.output_size = output_size
self.skip_bias_add = skip_bias_add
self.weight = Parameter(torch.empty(self.output_size,
self.input_size,
))
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size,
)
)
init.normal_(self.weight)
if bias:
self.bias = Parameter(torch.empty(self.output_size))
@@ -46,7 +45,7 @@ class Linear(nn.Module):
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
def forward(self, input_):
# Matrix multiply.
@@ -59,5 +58,7 @@ class Linear(nn.Module):
return output
def __repr__(self):
return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
return (
f"Linear(in_features={self.input_size}, out_features={self.output_size}, "
+ f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})"
)

View File

@@ -1,10 +1,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .linear import Linear
from colossalai.kernel.jit import bias_gelu_impl
from .linear import Linear
class TransformerMLP(nn.Module):
"""MLP.
@@ -18,19 +18,13 @@ class TransformerMLP(nn.Module):
super(TransformerMLP, self).__init__()
# Project to 4h.
self.dense_h_to_4h = Linear(
hidden_size,
int(hidden_size*mlp_ratio),
skip_bias_add=True)
self.dense_h_to_4h = Linear(hidden_size, int(hidden_size * mlp_ratio), skip_bias_add=True)
self.bias_gelu_fusion = fuse_gelu
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = Linear(
int(hidden_size*mlp_ratio),
hidden_size,
skip_bias_add=True)
self.dense_4h_to_h = Linear(int(hidden_size * mlp_ratio), hidden_size, skip_bias_add=True)
def forward(self, hidden_states):
# hidden states should be in the shape of [s, b, h]
@@ -39,11 +33,9 @@ class TransformerMLP(nn.Module):
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
intermediate_parallel = \
bias_gelu_impl(intermediate_parallel, bias_parallel)
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = \
self.activation_func(intermediate_parallel + bias_parallel)
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)

View File

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from .linear import Linear

View File

@@ -6,7 +6,6 @@ from colossalai.legacy.core import global_context as gpc
class PreProcessor(nn.Module):
def __init__(self, sub_seq_length):
super().__init__()
self.sub_seq_length = sub_seq_length
@@ -15,10 +14,9 @@ class PreProcessor(nn.Module):
# Create position ids
seq_length = token_ids.size(1)
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
position_ids = torch.arange(seq_length * local_rank,
seq_length * (local_rank + 1),
dtype=torch.long,
device=token_ids.device)
position_ids = torch.arange(
seq_length * local_rank, seq_length * (local_rank + 1), dtype=torch.long, device=token_ids.device
)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
@@ -42,7 +40,7 @@ class PreProcessor(nn.Module):
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Convert attention mask to binary:
extended_attention_mask = (extended_attention_mask < 0.5)
extended_attention_mask = extended_attention_mask < 0.5
return extended_attention_mask

View File

@@ -12,7 +12,6 @@ from colossalai.kernel import LayerNorm
from colossalai.legacy.amp import AMP_TYPE
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine.schedule import PipelineSchedule
from colossalai.legacy.utils import is_using_pp
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import FusedAdam
@@ -31,7 +30,7 @@ def process_batch_data(batch_data):
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
parser.add_argument("-s", "--synthetic", action="store_true", help="whether use synthetic data")
return parser.parse_args()
@@ -48,37 +47,39 @@ def pipeline_data_process_func(stage_output, micro_batch_data):
def main():
# initialize
args = parse_args()
colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')
parse_args()
colossalai.launch_from_torch(config="./config.py", seed=1234, backend="nccl")
logger = get_dist_logger()
# build synthetic dataloader
BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
VOCAB_SIZE = 30528
trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
vocab_size=VOCAB_SIZE,
seq_length=gpc.config.SEQ_LENGTH)
validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
vocab_size=VOCAB_SIZE,
seq_length=gpc.config.SEQ_LENGTH)
trainloader = DummyDataloader(
batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH
)
validloader = DummyDataloader(
batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH
)
logger.info("Dataloaders are built", ranks=[0])
# build model
if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE:
if hasattr(gpc.config, "fp16") and gpc.config.fp16.get("mode") == AMP_TYPE.NAIVE:
is_naive_fp16 = True
else:
is_naive_fp16 = False
use_pipeline = is_using_pp()
kwargs = dict(vocab_size=VOCAB_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
max_sequence_length=gpc.config.SEQ_LENGTH,
num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
convert_fp16_to_fp32_in_softmax=True,
is_naive_fp16=is_naive_fp16,
add_binary_head=gpc.config.ADD_BINARY_HEAD)
kwargs = dict(
vocab_size=VOCAB_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
max_sequence_length=gpc.config.SEQ_LENGTH,
num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
convert_fp16_to_fp32_in_softmax=True,
is_naive_fp16=is_naive_fp16,
add_binary_head=gpc.config.ADD_BINARY_HEAD,
)
if use_pipeline:
model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs)
@@ -99,35 +100,39 @@ def main():
logger.info("Criterion is built", ranks=[0])
# layernorm and bias has no weight decay
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
weight_decay_params = {"params": []}
no_weight_decay_params = {"params": [], "weight_decay": 0.0}
for module_ in model.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None])
no_weight_decay_params["params"].extend([p for p in list(module_._parameters.values()) if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias'])
weight_decay_params["params"].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n != "bias"]
)
no_weight_decay_params["params"].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n == "bias"]
)
logger.info(
f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}"
)
# optimizer
optimizer = FusedAdam((weight_decay_params, no_weight_decay_params),
lr=gpc.config.LR,
weight_decay=gpc.config.WEIGHT_DECAY)
optimizer = FusedAdam(
(weight_decay_params, no_weight_decay_params), lr=gpc.config.LR, weight_decay=gpc.config.WEIGHT_DECAY
)
logger.info("Optimizer is built", ranks=[0])
# lr scheduler
# follow Megatron-LM setting
warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION)
lr_scheduler = AnnealingLR(optimizer=optimizer,
max_lr=gpc.config.LR,
min_lr=gpc.config.MIN_LR,
warmup_steps=warmup_steps,
decay_steps=gpc.config.DECAY_ITERS,
decay_style='linear')
lr_scheduler = AnnealingLR(
optimizer=optimizer,
max_lr=gpc.config.LR,
min_lr=gpc.config.MIN_LR,
warmup_steps=warmup_steps,
decay_steps=gpc.config.DECAY_ITERS,
decay_style="linear",
)
logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")
# # init
@@ -135,7 +140,6 @@ def main():
# build timer
timer = MultiTimer()
skip_iters = 0
# build loss tracker
accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda()
@@ -150,7 +154,7 @@ def main():
logger.info("start training")
for step in range(1, gpc.config.TRAIN_ITERS + 1):
timer.start('train-iterations')
timer.start("train-iterations")
engine.train()
if use_pipeline:
engine.zero_grad()
@@ -158,13 +162,14 @@ def main():
engine.step()
else:
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
trainloader)
trainloader
)
engine.zero_grad()
lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
engine.backward(train_loss)
engine.step()
timer.stop('train-iterations', keep_in_history=True)
timer.stop("train-iterations", keep_in_history=True)
if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
accumulated_train_loss += train_loss
@@ -177,12 +182,18 @@ def main():
for j in range(gpc.config.EVAL_ITERS):
with torch.no_grad():
if use_pipeline:
_, _, eval_loss = engine.execute_schedule(valid_data_iter,
forward_only=True,
return_output_label=False)
_, _, eval_loss = engine.execute_schedule(
valid_data_iter, forward_only=True, return_output_label=False
)
else:
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
validloader)
(
tokens,
types,
sentence_order,
loss_mask,
lm_labels,
padding_mask,
) = get_batch_for_sequence_parallel(validloader)
lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
@@ -196,18 +207,22 @@ def main():
timer_string = []
for n, t in timer:
timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}")
timer_string = ' | '.join(timer_string)
lr = list(engine.optimizer.param_groups)[0]['lr']
timer_string = " | ".join(timer_string)
lr = list(engine.optimizer.param_groups)[0]["lr"]
loss_scale = engine.optimizer.optim.loss_scale.item()
if gpc.is_initialized(ParallelMode.PIPELINE):
ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]]
else:
ranks = [0]
logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' +
f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' +
f"| Learning rate: {lr} | " + timer_string,
ranks=ranks)
logger.info(
f"Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} "
+ f"| Eval Loss: {accumulated_eval_loss.item():.5g} "
+ f"| Loss Scale: {loss_scale}"
+ f"| Learning rate: {lr} | "
+ timer_string,
ranks=ranks,
)
for n, t in timer:
t.reset()
@@ -215,5 +230,5 @@ def main():
accumulated_train_loss.zero_()
if __name__ == '__main__':
if __name__ == "__main__":
main()