mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 19:17:30 +00:00
[refactor] pipeline, put runtime schedule into engine. (#627)
This commit is contained in:
@@ -7,6 +7,7 @@ IMG_SIZE = 224
|
||||
DIM = 768
|
||||
NUM_CLASSES = 10
|
||||
NUM_ATTN_HEADS = 12
|
||||
NUM_MICRO_BATCHES = 2
|
||||
|
||||
# resnet 18
|
||||
model = dict(type='VanillaResNet',
|
||||
|
@@ -19,7 +19,6 @@ from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
||||
BATCH_SIZE = 4
|
||||
NUM_MICRO = 2
|
||||
|
||||
DIR_PATH = osp.dirname(osp.realpath(__file__))
|
||||
CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
|
||||
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
|
||||
engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)
|
||||
|
||||
# build pipeline schedule
|
||||
schedule = PipelineSchedule(num_microbatches=NUM_MICRO)
|
||||
schedule = engine.schedule
|
||||
|
||||
# run schedule
|
||||
data_iter = iter(train_dataloader)
|
||||
|
Reference in New Issue
Block a user