Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
This commit is contained in:
Frank Lee
2021-11-18 19:45:06 +08:00
committed by GitHub
parent 2b05de4c64
commit 3defa32aee
80 changed files with 2194 additions and 1584 deletions

View File

@@ -38,5 +38,3 @@ optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
# set_device_func = lambda global_rank, world_size: global_rank % 4
seed = 1024

View File

@@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
fp16 = dict(mode=AMP_TYPE.APEX)
# set_device_func = lambda global_rank, world_size: global_rank % 4
seed = 1024

View File

@@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
fp16 = dict(mode=AMP_TYPE.TORCH)
# set_device_func = lambda global_rank, world_size: global_rank % 4
seed = 1024

View File

@@ -38,11 +38,9 @@ parallel = dict(
tensor=dict(size=1, mode=None)
)
schedule = dict(
num_microbatches=4
engine = dict(
schedule=dict(
num_microbatches=4
)
)
num_pipeling_batches = 2
seed = 1024
lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5)
num_epochs = 10

View File

@@ -8,7 +8,6 @@ import torch
from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage
@@ -24,20 +23,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_am
def run_no_pipeline(config):
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
output, label, loss = engine.step()
output, label, loss = engine.step(iter(train_dataloader))
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine finished')

View File

@@ -8,7 +8,6 @@ import torch
from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage
@@ -26,21 +25,14 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py')
def test_no_pipeline(config):
print('Test no pipeline engine start')
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
output, label, loss = engine.step()
output, label, loss = engine.step(iter(train_dataloader))
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine finished')

View File

@@ -8,7 +8,6 @@ import torch
from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import report_memory_usage
@@ -26,21 +25,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_a
def test_no_pipeline(config):
print('Test no pipeline engine start')
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
output, label, loss = engine.step()
output, label, loss = engine.step(iter(train_dataloader))
logger.info('Rank {} returns: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine finished')

View File

@@ -5,6 +5,7 @@ import os.path as osp
import pytest
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import initialize
from colossalai.logging import get_global_dist_logger
@@ -22,13 +23,25 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_schedule():
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
logger = get_global_dist_logger()
schedule.zero_grad()
output, label, losses = schedule.forward_backward_step(forward_only=False)
schedule.step()
logger.info('losses: {}'.format([loss.item() for loss in losses]))
model = engine.model
optimizer = engine.optimizer
criterion = engine.criterion
schedule = engine._schedule
output, label, loss = schedule.forward_backward_step(
data_iter=iter(train_dataloader),
model=model,
optimizer=optimizer,
criterion=criterion,
forward_only=False
)
schedule.optimizer_step(model, optimizer)
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info('losses: {}'.format(loss))
gpc.destroy()
logger.info('training finished')

View File

@@ -9,7 +9,6 @@ import torch
from colossalai import initialize
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
NUM_BATCH = 128
@@ -23,22 +22,14 @@ PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
def run_pipeline(config):
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
engine, train_dataloader, test_dataloader = initialize(config)
logger = get_global_dist_logger()
rank = torch.distributed.get_rank()
engine = Engine(model=model,
train_dataloader=train_dataloader,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule)
engine.train()
logger.info('lr = %g' % engine.get_lr())
outputs, labels, loss = engine.step()
outputs, labels, loss = engine.step(iter(train_dataloader))
if gpc.is_last_rank(ParallelMode.PIPELINE):
logger.info('losses: {}'.format(rank, loss.item()))
logger.info('lr = %g' % engine.get_lr())
gpc.destroy()
logger.info('Test engine pipeline finished')