Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 13:11:05 +00:00)
Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* Fix FP16 optimizer and adapt torch amp to tensor parallel (#18)
* Fixed bugs in the compatibility between torch amp and tensor parallel, and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699.
* Improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
@@ -38,5 +38,3 @@ optimizer = dict(type='Adam', lr=0.001)
 loss = dict(type='CrossEntropyLoss')

 # set_device_func = lambda global_rank, world_size: global_rank % 4
 seed = 1024
@@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
 loss = dict(type='CrossEntropyLoss')
 fp16 = dict(mode=AMP_TYPE.APEX)

 # set_device_func = lambda global_rank, world_size: global_rank % 4
 seed = 1024
@@ -40,6 +40,3 @@ optimizer = dict(type='Adam', lr=0.001)
 loss = dict(type='CrossEntropyLoss')
 fp16 = dict(mode=AMP_TYPE.TORCH)

 # set_device_func = lambda global_rank, world_size: global_rank % 4
 seed = 1024
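For context, the two configs above differ only in which AMP backend they select. A minimal sketch of such a config file follows; the import path for AMP_TYPE is an assumption (adjust it to wherever the installed ColossalAI version exposes it), everything else is taken from the hunks.

# Hedged sketch of a non-pipeline AMP config; the AMP_TYPE import path is assumed.
from colossalai.amp import AMP_TYPE  # assumed import location

optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')

# AMP_TYPE.APEX selects NVIDIA Apex AMP, AMP_TYPE.TORCH selects torch.cuda.amp.
fp16 = dict(mode=AMP_TYPE.TORCH)

seed = 1024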
@@ -38,11 +38,9 @@ parallel = dict(
     tensor=dict(size=1, mode=None)
 )

-schedule = dict(
-    num_microbatches=4
-)
-num_pipeling_batches = 2
+engine = dict(
+    schedule=dict(
+        num_microbatches=4
+    )
+)
 seed = 1024
 lr_scheduler = dict(type='LinearWarmupLR', warmup_steps=5)

 num_epochs = 10
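The hunk above captures the config-level change made across the example configs: the top-level schedule dict (and num_pipeling_batches) is replaced by an engine dict that owns the schedule settings. A minimal sketch of the new shape, using only keys visible in the hunk (the rest of the parallel block is elided because the diff does not show it):

# Sketch of the updated config shape; only keys visible in the hunk above are used.
parallel = dict(
    tensor=dict(size=1, mode=None)
)

# Schedule settings now live under `engine` rather than a top-level `schedule` dict.
engine = dict(
    schedule=dict(
        num_microbatches=4
    )
)

seed = 1024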
@@ -8,7 +8,6 @@ import torch

 from colossalai import initialize
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import report_memory_usage
@@ -24,20 +23,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_apex_am

 def run_no_pipeline(config):
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()
     rank = torch.distributed.get_rank()

-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    schedule=schedule)
     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    output, label, loss = engine.step()
+    output, label, loss = engine.step(iter(train_dataloader))
     logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())

     gpc.destroy()
     logger.info('Test engine finished')
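Taken together, the added lines in this hunk show the updated calling convention: initialize() now returns a ready-to-use engine, and engine.step() consumes an explicit data iterator. A hedged sketch of a complete test-style step under that convention; the function name run_one_step and the config-path argument are placeholders, the calls themselves mirror the diff above.

# Hedged sketch of the new non-pipeline flow, based on the added lines above.
import torch

from colossalai import initialize
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger


def run_one_step(config_path):
    # initialize() now builds the engine (model, optimizer, criterion, schedule) internally.
    engine, train_dataloader, test_dataloader = initialize(config_path)
    logger = get_global_dist_logger()
    rank = torch.distributed.get_rank()

    engine.train()
    # engine.step() takes a data iterator in the updated API.
    output, label, loss = engine.step(iter(train_dataloader))
    logger.info('Rank {} returns: {}'.format(rank, loss.item()))

    gpc.destroy()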
@@ -8,7 +8,6 @@ import torch

 from colossalai import initialize
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import report_memory_usage
@@ -26,21 +25,14 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet.py')
 def test_no_pipeline(config):
     print('Test no pipeline engine start')

-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()

     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    schedule=schedule)

     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    output, label, loss = engine.step()
+    output, label, loss = engine.step(iter(train_dataloader))
     logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())

     gpc.destroy()
     logger.info('Test engine finished')
@@ -8,7 +8,6 @@ import torch

 from colossalai import initialize
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
 from colossalai.utils import report_memory_usage
@@ -26,21 +25,13 @@ NO_PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/non_pipeline_resnet_torch_a
 def test_no_pipeline(config):
     print('Test no pipeline engine start')

-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()

     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    schedule=schedule)

     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    output, label, loss = engine.step()
+    output, label, loss = engine.step(iter(train_dataloader))
     logger.info('Rank {} returns: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())

     gpc.destroy()
     logger.info('Test engine finished')
@@ -5,6 +5,7 @@ import os.path as osp

 import pytest

+from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import initialize
 from colossalai.logging import get_global_dist_logger
@@ -22,13 +23,25 @@ CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
 @pytest.mark.skip("This test should be invoked using the test.sh provided")
 @pytest.mark.dist
 def test_schedule():
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(CONFIG_PATH)
+    engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)
     logger = get_global_dist_logger()

-    schedule.zero_grad()
-    output, label, losses = schedule.forward_backward_step(forward_only=False)
-    schedule.step()
-    logger.info('losses: {}'.format([loss.item() for loss in losses]))
+    model = engine.model
+    optimizer = engine.optimizer
+    criterion = engine.criterion
+    schedule = engine._schedule
+
+    output, label, loss = schedule.forward_backward_step(
+        data_iter=iter(train_dataloader),
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        forward_only=False
+    )
+    schedule.optimizer_step(model, optimizer)
+
+    if gpc.is_last_rank(ParallelMode.PIPELINE):
+        logger.info('losses: {}'.format(loss))

     gpc.destroy()
     logger.info('training finished')
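The new test body above also documents how the schedule can be driven directly when finer control than engine.step() is needed: pull the components off the engine, run one forward/backward pass, then apply the optimizer step. A hedged sketch mirroring that pattern; CONFIG_PATH is a placeholder, and engine._schedule is a private attribute used here exactly as the test uses it, so it may change.

# Hedged sketch of driving the schedule directly, as in the updated test above.
from colossalai.initialize import initialize

engine, train_dataloader, test_dataloader = initialize(CONFIG_PATH)  # CONFIG_PATH: your config file

model = engine.model
optimizer = engine.optimizer
criterion = engine.criterion
schedule = engine._schedule  # private attribute, taken verbatim from the test

output, label, loss = schedule.forward_backward_step(
    data_iter=iter(train_dataloader),
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    forward_only=False   # run the backward pass as well, so gradients are produced
)
schedule.optimizer_step(model, optimizer)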
@@ -9,7 +9,6 @@ import torch
 from colossalai import initialize
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger

 NUM_BATCH = 128
@@ -23,22 +22,14 @@ PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')


 def run_pipeline(config):
-    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = initialize(config)
+    engine, train_dataloader, test_dataloader = initialize(config)
     logger = get_global_dist_logger()
     rank = torch.distributed.get_rank()
-    engine = Engine(model=model,
-                    train_dataloader=train_dataloader,
-                    criterion=criterion,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    schedule=schedule)

     engine.train()
-    logger.info('lr = %g' % engine.get_lr())
-    outputs, labels, loss = engine.step()
+    outputs, labels, loss = engine.step(iter(train_dataloader))
     if gpc.is_last_rank(ParallelMode.PIPELINE):
         logger.info('losses: {}'.format(rank, loss.item()))
-    logger.info('lr = %g' % engine.get_lr())

     gpc.destroy()
     logger.info('Test engine pipeline finished')
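For the pipeline case, the one extra wrinkle visible in this hunk is that the loss returned by engine.step() is only meaningful on the last pipeline stage, hence the gpc.is_last_rank(ParallelMode.PIPELINE) guard. A minimal hedged sketch of that guard; the helper name log_pipeline_loss is hypothetical, while the calls and the guard are taken from the diff above.

# Hedged sketch of the pipeline-aware logging pattern used in the test above.
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


def log_pipeline_loss(engine, train_dataloader, logger):
    engine.train()
    outputs, labels, loss = engine.step(iter(train_dataloader))
    # Only the last pipeline stage holds the final loss, so guard the access.
    if gpc.is_last_rank(ParallelMode.PIPELINE):
        logger.info('loss = {}'.format(loss.item()))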