[refactor] pipeline, put runtime schedule into engine. (#627)

Author: YuliangLiu0306
Date: 2022-04-03 20:46:45 +08:00
Committed by: GitHub
Parent: e5d615aeee
Commit: ade05a5d83
9 changed files with 68 additions and 49 deletions


@@ -23,9 +23,9 @@ from torchvision.datasets import CIFAR10
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
-CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.NAIVE),
-              gradient_accumulation=2)
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
+              gradient_accumulation=2)
 
 
 def run_trainer(rank, world_size, port):
@@ -63,10 +63,9 @@ def run_trainer(rank, world_size, port):
                                                      train_dataloader,
                                                      lr_scheduler=lr_scheduler)
-    schedule = PipelineSchedule(num_microbatches=2)
     logger = get_dist_logger()
 
-    trainer = Trainer(engine=engine, logger=logger, schedule=schedule)
+    trainer = Trainer(engine=engine, logger=logger)
 
     hook_list = [
         hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
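
For reference, here is a minimal sketch of what the post-refactor usage looks like: the pipeline runtime schedule is now derived inside the engine from NUM_MICRO_BATCHES in the config, so the script no longer constructs a PipelineSchedule or passes schedule= to Trainer. This is not part of the commit; the colossalai.launch / colossalai.initialize / Trainer.fit signatures are assumed from the ColossalAI API of this era, and the model, optimizer, criterion and dataloader construction from the rest of the example file is elided.

    # sketch only; model/optimizer/criterion/dataloader setup elided
    import colossalai
    from colossalai.amp import AMP_TYPE
    from colossalai.logging import get_dist_logger
    from colossalai.trainer import Trainer, hooks

    NUM_EPOCHS = 60
    CONFIG = dict(NUM_MICRO_BATCHES=2,
                  parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
                  fp16=dict(mode=AMP_TYPE.NAIVE),
                  gradient_accumulation=2)


    def run_trainer(rank, world_size, port):
        # launch/initialize calls assumed from the ColossalAI API used by this example
        colossalai.launch(config=CONFIG, rank=rank, world_size=world_size,
                          host='localhost', port=port, backend='nccl')

        # model, optimizer, criterion, train_dataloader, lr_scheduler are built
        # as in the original example file (elided here)
        engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
                                                                          optimizer,
                                                                          criterion,
                                                                          train_dataloader,
                                                                          lr_scheduler=lr_scheduler)
        logger = get_dist_logger()

        # no schedule= argument: the engine now owns the pipeline schedule
        trainer = Trainer(engine=engine, logger=logger)

        hook_list = [
            hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
        ]
        trainer.fit(train_dataloader=train_dataloader,
                    epochs=NUM_EPOCHS,
                    hooks=hook_list,
                    display_progress=True)

The practical effect of the refactor is that schedule selection becomes a config concern (NUM_MICRO_BATCHES plus the parallel settings) rather than something each training script wires up by hand.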