Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 13:00:52 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
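For context, a diff like the one below is what a pre-commit pass built around import sorting and black-style formatting produces. The sketch of a .pre-commit-config.yaml below is an assumption inferred from the commit notes and the reformatted output (sorted imports, double quotes, trailing commas); the rev values and the CUDA exclude pattern are illustrative, not the exact contents of this PR.

# Hypothetical .pre-commit-config.yaml sketch; hook versions are illustrative.
repos:
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort          # reorders imports, as in the first hunk below
  - repo: https://github.com/psf/black
    rev: 23.9.1
    hooks:
      - id: black          # double quotes, trailing commas, one-argument-per-line call splitting
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v13.0.1
    hooks:
      - id: clang-format
        exclude: \.cu$     # "ignore cuda for clang-format"

With the hooks installed, reformatting every tracked file is a single command: pre-commit run --all-files.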
@@ -1,29 +1,34 @@
-from .blendable_dataset import BlendableDataset
-from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
-from .bert_dataset import BertDataset
 from colossalai.logging import get_dist_logger
 
-DSET_TYPE_BERT = 'standard_bert'
-DSET_TYPE_ICT = 'ict'
-DSET_TYPE_T5 = 't5'
+from .bert_dataset import BertDataset
+from .blendable_dataset import BlendableDataset
+from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
+
+DSET_TYPE_BERT = "standard_bert"
+DSET_TYPE_ICT = "ict"
+DSET_TYPE_T5 = "t5"
 
 DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
 
 
-def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
-                                     train_valid_test_num_samples,
-                                     max_seq_length, masked_lm_prob,
-                                     short_seq_prob, seed, skip_warmup,
-                                     binary_head,
-                                     dataset_type='standard_bert'):
-
+def _build_train_valid_test_datasets(
+    data_prefix,
+    data_impl,
+    splits_string,
+    train_valid_test_num_samples,
+    max_seq_length,
+    masked_lm_prob,
+    short_seq_prob,
+    seed,
+    skip_warmup,
+    binary_head,
+    dataset_type="standard_bert",
+):
     if dataset_type not in DSET_TYPES:
         raise ValueError("Invalid dataset_type: ", dataset_type)
 
     # Indexed dataset.
-    indexed_dataset = get_indexed_dataset_(data_prefix,
-                                           data_impl,
-                                           skip_warmup)
+    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
 
     # Get start and end indices of train/valid/train into doc-idx
     # Note that doc-idx is designed to be num-docs + 1 so we can
@@ -34,22 +39,25 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     logger = get_dist_logger()
 
     # Print stats about the splits.
-    logger.info('\n > dataset split:', ranks=[0])
+    logger.info("\n > dataset split:", ranks=[0])
 
     def print_split_stats(name, index):
         start_index = indexed_dataset.doc_idx[splits[index]]
         end_index = indexed_dataset.doc_idx[splits[index + 1]]
-        logger.info('\n {}:'.format(name) +
-                    '\n document indices in [{}, {}) total of {} documents'.format(
-                        splits[index], splits[index + 1],
-                        splits[index + 1] - splits[index]) +
-                    '\n sentence indices in [{}, {}) total of {} sentences'.format(
-                        start_index, end_index,
-                        end_index - start_index),
-                    ranks=[0])
-    print_split_stats('train', 0)
-    print_split_stats('validation', 1)
-    print_split_stats('test', 2)
+        logger.info(
+            "\n {}:".format(name)
+            + "\n document indices in [{}, {}) total of {} documents".format(
+                splits[index], splits[index + 1], splits[index + 1] - splits[index]
+            )
+            + "\n sentence indices in [{}, {}) total of {} sentences".format(
+                start_index, end_index, end_index - start_index
+            ),
+            ranks=[0],
+        )
+
+    print_split_stats("train", 0)
+    print_split_stats("validation", 1)
+    print_split_stats("test", 2)
 
     def build_dataset(index, name):
         dataset = None
@@ -80,44 +88,53 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                 masked_lm_prob=masked_lm_prob,
                 short_seq_prob=short_seq_prob,
                 binary_head=binary_head,
-                **kwargs
+                **kwargs,
             )
 
         # Set the original pointer so dataset remains the main dataset.
         indexed_dataset.set_doc_idx(doc_idx_ptr)
         # Checks.
         assert indexed_dataset.doc_idx[0] == 0
-        assert indexed_dataset.doc_idx.shape[0] == \
-            (total_num_of_documents + 1)
+        assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
         return dataset
 
-    train_dataset = build_dataset(0, 'train')
-    valid_dataset = build_dataset(1, 'valid')
-    test_dataset = build_dataset(2, 'test')
+    train_dataset = build_dataset(0, "train")
+    valid_dataset = build_dataset(1, "valid")
+    test_dataset = build_dataset(2, "test")
 
     return (train_dataset, valid_dataset, test_dataset)
 
 
-def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
-                                    train_valid_test_num_samples,
-                                    max_seq_length, masked_lm_prob,
-                                    short_seq_prob, seed, skip_warmup,
-                                    binary_head,
-                                    dataset_type='standard_bert'):
-
+def build_train_valid_test_datasets(
+    data_prefix,
+    data_impl,
+    splits_string,
+    train_valid_test_num_samples,
+    max_seq_length,
+    masked_lm_prob,
+    short_seq_prob,
+    seed,
+    skip_warmup,
+    binary_head,
+    dataset_type="standard_bert",
+):
     if len(data_prefix) == 1:
-        return _build_train_valid_test_datasets(data_prefix[0],
-                                                data_impl, splits_string,
-                                                train_valid_test_num_samples,
-                                                max_seq_length, masked_lm_prob,
-                                                short_seq_prob, seed,
-                                                skip_warmup,
-                                                binary_head,
-                                                dataset_type=dataset_type)
+        return _build_train_valid_test_datasets(
+            data_prefix[0],
+            data_impl,
+            splits_string,
+            train_valid_test_num_samples,
+            max_seq_length,
+            masked_lm_prob,
+            short_seq_prob,
+            seed,
+            skip_warmup,
+            binary_head,
+            dataset_type=dataset_type,
+        )
     # Blending dataset.
     # Parse the values.
-    output = get_datasets_weights_and_num_samples(data_prefix,
-                                                  train_valid_test_num_samples)
+    output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples)
     prefixes, weights, datasets_train_valid_test_num_samples = output
 
     # Build individual datasets.
@@ -126,10 +143,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     test_datasets = []
     for i in range(len(prefixes)):
         train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
-            prefixes[i], data_impl, splits_string,
+            prefixes[i],
+            data_impl,
+            splits_string,
             datasets_train_valid_test_num_samples[i],
-            max_seq_length, masked_lm_prob, short_seq_prob,
-            seed, skip_warmup, binary_head, dataset_type=dataset_type)
+            max_seq_length,
+            masked_lm_prob,
+            short_seq_prob,
+            seed,
+            skip_warmup,
+            binary_head,
+            dataset_type=dataset_type,
+        )
         if train_ds:
             train_datasets.append(train_ds)
         if valid_ds:
@@ -148,5 +173,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     if test_datasets:
         blending_test_dataset = BlendableDataset(test_datasets, weights)
 
-    return (blending_train_dataset, blending_valid_dataset,
-            blending_test_dataset)
+    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)