[refactor] pipeline, put runtime schedule into engine. (#627)

This commit is contained in:
YuliangLiu0306
2022-04-03 20:46:45 +08:00
committed by GitHub
parent e5d615aeee
commit ade05a5d83
9 changed files with 68 additions and 49 deletions

View File

@@ -7,6 +7,7 @@ IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12
NUM_MICRO_BATCHES = 2
# resnet 18
model = dict(type='VanillaResNet',

View File

@@ -19,7 +19,6 @@ from torchvision import transforms
from torchvision.datasets import CIFAR10
BATCH_SIZE = 4
NUM_MICRO = 2
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)
# build pipeline schedule
schedule = PipelineSchedule(num_microbatches=NUM_MICRO)
schedule = engine.schedule
# run schedule
data_iter = iter(train_dataloader)