[legacy] move trainer to legacy (#4545)

* [legacy] move trainer to legacy

* [doc] update docs related to trainer

* [test] ignore legacy test
Hongxin Liu
2023-08-31 13:51:28 +08:00
parent 807e01a4ba
commit 89fe027787
32 changed files with 63 additions and 153 deletions
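
In practice, moving the trainer into the legacy namespace changes the import path that downstream code and docs use. A minimal sketch of the migration, assuming the trainer was previously exposed at the package top level (the old path is inferred from the commit title and is not shown in this excerpt); the new path is the one used by the tests below:

# assumed old import path, prior to this commit
from colossalai.trainer import Trainer

# new import path, as used in the updated tests
from colossalai.legacy.trainer import Trainer

The "[test] ignore legacy test" change presumably excludes the moved tests from the default test collection (for example via pytest's --ignore option); the exact mechanism is not part of this excerpt.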


@@ -0,0 +1,108 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
import torch.distributed as dist

from colossalai.communication import (
    recv_backward,
    recv_forward,
    recv_obj_meta,
    send_backward,
    send_backward_recv_forward,
    send_forward,
    send_forward_recv_backward,
    send_obj_meta,
)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device

BATCH_SIZE = 4
SEQ_LENGTH = 2
HIDDEN_SIZE = 16

CONFIG = dict(parallel=dict(pipeline=dict(size=4), tensor=dict(size=1, mode=None)), seed=1024)


def check_equal(A, B):
    return torch.allclose(A, B, rtol=1e-5, atol=1e-3)


def check_forward(output_tensor, rank, logger):
    # pass the shared reference tensor down the pipeline and verify it arrives intact
    dist.barrier()
    if gpc.is_first_rank(ParallelMode.PIPELINE):
        tensor = output_tensor.clone()
    else:
        tensor = recv_forward(output_tensor.shape)
        logger.info('Rank {} received forward. Correct tensor: {}'.format(rank, check_equal(tensor, output_tensor)))

    if not gpc.is_last_rank(ParallelMode.PIPELINE):
        send_forward(tensor)
        logger.info('Rank {} sent forward.'.format(rank))


def check_backward(output_grad, rank, logger):
    # pass the shared reference gradient back up the pipeline and verify it arrives intact
    dist.barrier()
    if gpc.is_last_rank(ParallelMode.PIPELINE):
        grad = output_grad.clone()
    else:
        grad = recv_backward(output_grad.shape)
        logger.info('Rank {} received backward. Correct grad: {}'.format(rank, check_equal(grad, output_grad)))

    if not gpc.is_first_rank(ParallelMode.PIPELINE):
        send_backward(grad)
        logger.info('Rank {} sent backward.'.format(rank))


def check_forward_backward(output_tensor, output_grad, rank, logger):
    # exercise the fused send/recv helpers in both directions
    dist.barrier()
    if not gpc.is_first_rank(ParallelMode.PIPELINE):
        tensor = send_backward_recv_forward(output_grad, output_tensor.shape)
        logger.info('Rank {} sent backward received forward. Correct tensor: {}'.format(
            rank, check_equal(tensor, output_tensor)))

    if not gpc.is_last_rank(ParallelMode.PIPELINE):
        grad = send_forward_recv_backward(output_tensor, output_grad.shape)
        logger.info('Rank {} sent forward received backward. Correct grad: {}'.format(
            rank, check_equal(grad, output_grad)))


def check_comm(size, rank, prev_rank, next_rank, logger):
    dtype = torch.float32
    device = get_current_device()
    tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    # all-reduce so every rank holds the same reference tensor and gradient
    tensor = torch.randn(tensor_shape, dtype=dtype, device=device)
    dist.all_reduce(tensor)
    grad = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.all_reduce(grad)

    check_forward(tensor, rank, logger)
    check_backward(grad, rank, logger)
    check_forward_backward(tensor, grad, rank, logger)


def run_check(rank, world_size, port):
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    logger = get_dist_logger()
    rank = gpc.get_global_rank()
    prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
    next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
    logger.info('Rank {0}: prev rank {1}, next rank {2}'.format(rank, prev_rank, next_rank))
    logger.info('Distributed environment is initialized.')

    check_comm(world_size, rank, prev_rank, next_rank, logger)

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_p2p():
    world_size = 4
    spawn(run_check, world_size)


if __name__ == '__main__':
    test_p2p()


@@ -0,0 +1,87 @@
# referenced from Megatron and used to test communication
import os
from pathlib import Path

import pytest
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_dataloader, print_rank_0

BATCH_SIZE = 8

CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None)))


def run_schedule(rank, world_size, port):
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model: split ResNet-18 into two pipeline stages
    model = resnet18(num_classes=10)

    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
        model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
    elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:

        class Flatten(nn.Module):

            def forward(self, x):
                return torch.flatten(x, 1)

        model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)

    print_rank_0('model is created')

    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                            ]))

    train_dataloader = get_dataloader(
        dataset=train_dataset,
        shuffle=True,
        add_sampler=True,
        batch_size=BATCH_SIZE,
        pin_memory=True,
    )

    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)

    # initialize
    engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

    # build pipeline schedule
    schedule = engine.schedule

    # run schedule for a single batch
    data_iter = iter(train_dataloader)
    schedule.forward_backward_step(engine, data_iter)

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_schedule():
    world_size = 2
    spawn(run_schedule, world_size)


if __name__ == '__main__':
    test_pipeline_schedule()
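
As a quick sanity check on the configuration above: the pipeline schedule splits each batch into NUM_MICRO_BATCHES micro-batches before pushing them through the stages, so with BATCH_SIZE = 8 and NUM_MICRO_BATCHES = 2 each stage is expected to work on micro-batches of 4 samples. A minimal sketch of that arithmetic, not part of the diff:

BATCH_SIZE = 8
NUM_MICRO_BATCHES = 2
micro_batch_size = BATCH_SIZE // NUM_MICRO_BATCHES    # how the schedule is assumed to split each batch
assert micro_batch_size == 4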


@@ -0,0 +1,59 @@
import pytest
import torch

import colossalai
from colossalai.amp.amp_type import AMP_TYPE
from colossalai.legacy.trainer import Trainer
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import MultiTimer
from tests.components_to_test.registry import non_distributed_component_funcs

BATCH_SIZE = 4
IMG_SIZE = 32
NUM_EPOCHS = 200

CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH))


@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'nested_model'])
def run_trainer(model_name):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    model = model_builder()
    optimizer = optimizer_class(model.parameters(), lr=1e-3)
    engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                         optimizer=optimizer,
                                                         criterion=criterion,
                                                         train_dataloader=train_dataloader)

    logger = get_dist_logger()
    logger.info("engine is built", ranks=[0])

    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("trainer is built", ranks=[0])

    logger.info("start training", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=NUM_EPOCHS,
                max_steps=3,
                display_progress=True,
                test_interval=5)
    torch.cuda.empty_cache()


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # run the parameterized trainer test once the distributed environment is up;
    # without this call the spawned workers would initialize and exit immediately
    run_trainer()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_trainer_no_pipeline():
    world_size = 4
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_trainer_no_pipeline()
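
The test above relies on colossalai.testing.parameterize, which, as used here, calls the decorated run_trainer once per listed model name, so run_dist can invoke it without arguments. A hypothetical stand-in decorator, written only to illustrate that calling convention (not the library's actual implementation):

import functools


def parameterize_sketch(arg_name, values):
    # hypothetical equivalent of colossalai.testing.parameterize, sketched only
    # to show why run_trainer() can be called with no explicit model_name
    def decorator(fn):

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})

        return wrapper

    return decorator


@parameterize_sketch('model_name', ['resnet18', 'nested_model'])
def demo(model_name):
    print(f'running with {model_name}')


demo()    # runs once per listed model name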


@@ -0,0 +1,96 @@
import os
from pathlib import Path

import pytest
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.trainer import Trainer
from colossalai.logging import get_dist_logger
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import MultiTimer, get_dataloader

BATCH_SIZE = 4
IMG_SIZE = 32
NUM_EPOCHS = 200

CONFIG = dict(
    NUM_MICRO_BATCHES=2,
    parallel=dict(pipeline=2),
)


def run_trainer_with_pipeline(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build model
    model = resnet18(num_classes=10)

    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
        model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
    elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:

        class Flatten(nn.Module):

            def forward(self, x):
                return torch.flatten(x, 1)

        model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)

    # build dataloaders
    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
                            download=True,
                            transform=transforms.Compose([
                                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                            ]))

    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      pin_memory=True,
                                      drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                            optimizer=optimizer,
                                                            criterion=criterion,
                                                            train_dataloader=train_dataloader)

    logger = get_dist_logger()
    logger.info("engine is built", ranks=[0])

    timer = MultiTimer()
    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("trainer is built", ranks=[0])

    logger.info("start training", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                epochs=NUM_EPOCHS,
                max_steps=3,
                display_progress=True,
                test_interval=5)

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_trainer_with_pipeline():
    world_size = 4
    spawn(run_trainer_with_pipeline, world_size)


if __name__ == '__main__':
    test_trainer_with_pipeline()