[example] integrate seq-parallel tutorial with CI (#2463)
@@ -1,9 +1,8 @@
 import argparse
 
 import torch
-from data import build_train_valid_test_data_iterators
 from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
-from data.tokenizer import get_padded_vocab_size, initialize_tokenizer
+from data.dummy_dataloader import DummyDataloader
 from loss_func.bert_loss import BertLoss
 from lr_scheduler import AnnealingLR
 from model.bert import BertForPretrain, build_pipeline_bert
@@ -36,7 +35,7 @@ def parse_args():
 
 
 def pipeline_data_process_func(stage_output, micro_batch_data):
-    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
+    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
     if gpc.is_first_rank(ParallelMode.PIPELINE):
         data = (tokens, padding_mask, types, lm_labels)
         label = (loss_mask, sentence_order)
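The hunk above ends before the function does, so only the first-pipeline-stage branch is visible. Below is a minimal sketch of the complete routing such a data-process function typically performs; the non-first-stage branch and the return value are assumptions extrapolated from the unpacking shown above, not part of this diff.

# Hypothetical completion of pipeline_data_process_func. The first pipeline
# stage feeds the raw tokens into the model; later stages consume the previous
# stage's output instead (this else-branch is an assumption, since the hunk
# context ends early). Imports follow the legacy ColossalAI API this example uses.
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

def pipeline_data_process_func(stage_output, micro_batch_data):
    # every micro-batch arrives as the same six-tuple on every stage
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
    if gpc.is_first_rank(ParallelMode.PIPELINE):
        # first stage: model input is the raw token batch
        data = (tokens, padding_mask, types, lm_labels)
        label = (loss_mask, sentence_order)
    else:
        # later stages: model input is the activation from the previous stage
        data = (stage_output, padding_mask, types, lm_labels)
        label = (loss_mask, sentence_order)
    return data, label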
@@ -53,36 +52,15 @@ def main():
 
     logger = get_dist_logger()
 
-    # build dataloader
-    if not args.synthetic:
-        initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
-        VOCAB_SIZE = get_padded_vocab_size()
-        trainloader, validloader, testloader = build_train_valid_test_data_iterators(
-            train_iters=gpc.config.TRAIN_ITERS,
-            global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
-            eval_interval=gpc.config.EVAL_INTERVAL,
-            eval_iters=gpc.config.EVAL_ITERS,
-            data_prefix=[gpc.config.DATA_PATH],
-            data_impl='mmap',
-            splits_string='949,50,1',
-            max_seq_length=gpc.config.SEQ_LENGTH,
-            masked_lm_prob=0.15,
-            short_seq_prob=0.1,
-            seed=1234,
-            skip_warmup=True,
-            binary_head=False,
-        )
-    else:
-        from data.dummy_dataloader import DummyDataloader
-
-        BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
-        VOCAB_SIZE = 30528
-        trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
-                                      vocab_size=VOCAB_SIZE,
-                                      seq_length=gpc.config.SEQ_LENGTH)
-        validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
-                                      vocab_size=VOCAB_SIZE,
-                                      seq_length=gpc.config.SEQ_LENGTH)
+    # build synthetic dataloader
+    BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
+    VOCAB_SIZE = 30528
+    trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
+                                  vocab_size=VOCAB_SIZE,
+                                  seq_length=gpc.config.SEQ_LENGTH)
+    validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
+                                  vocab_size=VOCAB_SIZE,
+                                  seq_length=gpc.config.SEQ_LENGTH)
 
     logger.info("Dataloaders are built", ranks=[0])
 
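Note that BATCH_SIZE_PER_GPUS splits the global batch across data-parallel ranks: with GLOBAL_BATCH_SIZE = 32 and 4 data-parallel ranks, each rank builds batches of 8. The DummyDataloader itself lives in data/dummy_dataloader.py, which this page does not show. Below is a minimal sketch of what such a synthetic BERT dataloader might look like, assuming it yields the same six-tuple that pipeline_data_process_func unpacks; the shapes and dtypes are guesses for illustration, not the repo's implementation.

# Hypothetical sketch of a synthetic BERT pretraining dataloader. It produces
# random batches with the (tokens, types, sentence_order, loss_mask, lm_labels,
# padding_mask) layout assumed above, so CI can run without a real corpus.
import torch

class DummyDataloader:

    def __init__(self, batch_size, vocab_size, seq_length):
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_length = seq_length

    def generate(self):
        # random token ids and labels in [0, vocab_size)
        tokens = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_length))
        lm_labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_length))
        # single-segment input: all segment (token type) ids are zero
        types = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long)
        # random binary sentence-order labels, one per sequence
        sentence_order = torch.randint(0, 2, (self.batch_size,))
        # compute the LM loss on every position and attend to every token
        loss_mask = torch.ones((self.batch_size, self.seq_length), dtype=torch.long)
        padding_mask = torch.ones((self.batch_size, self.seq_length), dtype=torch.long)
        return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask

    def __iter__(self):
        return self

    def __next__(self):
        # an endless iterator: training length is bounded by TRAIN_ITERS, not the data
        return self.generate()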