[refactor] pipeline, put runtime schedule into engine. (#627)
@@ -23,9 +23,9 @@ from torchvision.datasets import CIFAR10
 BATCH_SIZE = 4
 NUM_EPOCHS = 60
 WARMUP_EPOCHS = 5
-CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
-              fp16=dict(mode=AMP_TYPE.NAIVE),
-              gradient_accumulation=2)
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
+              gradient_accumulation=2)


 def run_trainer(rank, world_size, port):
@@ -63,10 +63,9 @@ def run_trainer(rank, world_size, port):
                                                        train_dataloader,
                                                        lr_scheduler=lr_scheduler)

-    schedule = PipelineSchedule(num_microbatches=2)
     logger = get_dist_logger()

-    trainer = Trainer(engine=engine, logger=logger, schedule=schedule)
+    trainer = Trainer(engine=engine, logger=logger)

     hook_list = [
         hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
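The net effect of the two hunks above: NUM_MICRO_BATCHES moves into the config, colossalai.initialize builds the pipeline schedule inside the engine, and Trainer drops its schedule argument. A minimal sketch of the new setup path, assuming a distributed context has already been launched; the model and data below are placeholders (the real test builds a CIFAR10 loader and an lr_scheduler, omitted here):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    import colossalai
    from colossalai.logging import get_dist_logger
    from colossalai.trainer import Trainer

    # Placeholder model and data; the real test uses a pipelined CIFAR10 model.
    model = nn.Linear(32, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    train_dataloader = DataLoader(TensorDataset(torch.randn(64, 32),
                                                torch.randint(0, 10, (64,))),
                                  batch_size=4)

    # Assumes colossalai.launch(...) was already called with a config carrying
    # NUM_MICRO_BATCHES=2 and parallel=dict(pipeline=2); initialize() now reads
    # NUM_MICRO_BATCHES and attaches the pipeline schedule to the engine itself.
    engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer,
                                                           criterion, train_dataloader)

    logger = get_dist_logger()
    trainer = Trainer(engine=engine, logger=logger)  # no schedule= argument anymore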
@@ -7,6 +7,7 @@ IMG_SIZE = 224
 DIM = 768
 NUM_CLASSES = 10
 NUM_ATTN_HEADS = 12
+NUM_MICRO_BATCHES = 2

 # resnet 18
 model = dict(type='VanillaResNet',
@@ -19,7 +19,6 @@ from torchvision import transforms
 from torchvision.datasets import CIFAR10

 BATCH_SIZE = 4
-NUM_MICRO = 2

 DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
@@ -57,7 +56,7 @@ def run_schedule(rank, world_size, port):
     engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

     # build pipeline schedule
-    schedule = PipelineSchedule(num_microbatches=NUM_MICRO)
+    schedule = engine.schedule

     # run schedule
     data_iter = iter(train_dataloader)
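With the schedule owned by the engine, this test fetches it via engine.schedule instead of constructing one from NUM_MICRO (now read as NUM_MICRO_BATCHES from the config, per the config hunk above). A hedged sketch of driving it manually, continuing from the initialize() call; the forward_backward_step signature is an assumption based on the schedule API of this era and is not shown in this diff:

    # Continuing from the initialize() call above.
    schedule = engine.schedule
    data_iter = iter(train_dataloader)

    engine.train()
    engine.zero_grad()
    # Runs one full pipelined forward/backward pass over NUM_MICRO_BATCHES
    # micro-batches; returns (output, label, loss) when return_loss=True.
    _, _, loss = schedule.forward_backward_step(engine, data_iter,
                                                forward_only=False,
                                                return_loss=True)
    engine.step()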
@@ -23,7 +23,7 @@ BATCH_SIZE = 4
 IMG_SIZE = 32
 NUM_EPOCHS = 200

-CONFIG = dict(parallel=dict(pipeline=2),)
+CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)


 def run_trainer_with_pipeline(rank, world_size, port):
@@ -69,9 +69,8 @@ def run_trainer_with_pipeline(rank, world_size, port):

     logger = get_dist_logger()
     logger.info("engine is built", ranks=[0])
-    pipe_schedule = PipelineSchedule(num_microbatches=2)
     timer = MultiTimer()
-    trainer = Trainer(engine=engine, schedule=pipe_schedule, logger=logger, timer=timer)
+    trainer = Trainer(engine=engine, logger=logger, timer=timer)
     logger.info("trainer is built", ranks=[0])

     logger.info("start training", ranks=[0])
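Downstream of the constructor, the trainer call site is unchanged: fit() still consumes the dataloader and hooks, and micro-batching now comes entirely from the config. A sketch of the rest of the test flow, with the fit() keyword arguments assumed from the Trainer API of this era; the hook choice here is a simplification (the first test registers LRSchedulerHook as shown in its hunk):

    from colossalai.trainer import hooks

    # The real tests also register LRSchedulerHook(lr_scheduler=..., by_epoch=False);
    # LossHook keeps this sketch self-contained.
    hook_list = [hooks.LossHook()]

    trainer.fit(train_dataloader=train_dataloader,
                epochs=NUM_EPOCHS,  # 60 or 200 in the tests above
                hooks=hook_list,
                display_progress=True)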