[refactor] pipeline, put runtime schedule into engine. (#627)

YuliangLiu0306 authored 2022-04-03 20:46:45 +08:00, committed by GitHub
parent e5d615aeee
commit ade05a5d83
9 changed files with 68 additions and 49 deletions
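
The change is the same across all four files shown below: user code no longer builds a PipelineSchedule and hands it to the Trainer; instead the number of micro-batches is declared as NUM_MICRO_BATCHES in the config, and the engine constructs the runtime schedule during colossalai.initialize(). A minimal sketch of the new pattern (model, optimizer, criterion, and train_dataloader stand in for whatever the script builds; this is an illustration, not part of the diff):

    import colossalai
    from colossalai.logging import get_dist_logger
    from colossalai.trainer import Trainer

    # NUM_MICRO_BATCHES now lives in the config, which is passed to
    # colossalai.launch(config=CONFIG, ...) earlier in the script; the
    # engine builds the pipeline schedule from it, so no PipelineSchedule
    # import is needed.
    CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2))

    engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer,
                                                           criterion, train_dataloader)
    trainer = Trainer(engine=engine, logger=get_dist_logger())  # no schedule kwarg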

File 1:

@@ -23,9 +23,9 @@ from torchvision.datasets import CIFAR10
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
-CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.NAIVE),
-              gradient_accumulation=2)
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
+              gradient_accumulation=2)
 
 
 def run_trainer(rank, world_size, port):
@@ -63,10 +63,9 @@ def run_trainer(rank, world_size, port):
                                          train_dataloader,
                                          lr_scheduler=lr_scheduler)
-    schedule = PipelineSchedule(num_microbatches=2)
     logger = get_dist_logger()
-    trainer = Trainer(engine=engine, logger=logger, schedule=schedule)
+    trainer = Trainer(engine=engine, logger=logger)
     hook_list = [
         hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),

File 2:

@@ -7,6 +7,7 @@ IMG_SIZE = 224
 DIM = 768
 NUM_CLASSES = 10
 NUM_ATTN_HEADS = 12
+NUM_MICRO_BATCHES = 2
 
 # resnet 18
 model = dict(type='VanillaResNet',

File 3:

@@ -19,7 +19,6 @@ from torchvision import transforms
 from torchvision.datasets import CIFAR10
 
 BATCH_SIZE = 4
-NUM_MICRO = 2
 
 DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
     engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)
 
     # build pipeline schedule
-    schedule = PipelineSchedule(num_microbatches=NUM_MICRO)
+    schedule = engine.schedule
 
     # run schedule
     data_iter = iter(train_dataloader)
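
For code that drives the schedule by hand, as in this test, the object returned by engine.schedule is used exactly where the standalone PipelineSchedule was before. A hedged sketch of the step that presumably follows in the test (the forward_backward_step name and its (output, label, loss) return are assumptions about this era of the API, not shown in the diff):

    engine.train()
    data_iter = iter(train_dataloader)
    # one step over NUM_MICRO_BATCHES micro-batches, with loss reduction
    output, label, loss = engine.schedule.forward_backward_step(
        engine, data_iter, forward_only=False, return_loss=True)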

File 4:

@@ -23,7 +23,7 @@ BATCH_SIZE = 4
 IMG_SIZE = 32
 NUM_EPOCHS = 200
 
-CONFIG = dict(parallel=dict(pipeline=2),)
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)
 
 def run_trainer_with_pipeline(rank, world_size, port):
@@ -69,9 +69,8 @@ def run_trainer_with_pipeline(rank, world_size, port):
     logger = get_dist_logger()
     logger.info("engine is built", ranks=[0])
-    pipe_schedule = PipelineSchedule(num_microbatches=2)
     timer = MultiTimer()
-    trainer = Trainer(engine=engine, schedule=pipe_schedule, logger=logger, timer=timer)
+    trainer = Trainer(engine=engine, logger=logger, timer=timer)
     logger.info("trainer is built", ranks=[0])
 
     logger.info("start training", ranks=[0])