[pipeline] refactor the pipeline module (#1087)

* [pipeline] refactor the pipeline module

* polish code
Frank Lee
2022-06-10 11:27:38 +08:00
committed by GitHub
parent bad5d4c0a1
commit 2b2dc1c86b
29 changed files with 366 additions and 1127 deletions

@@ -8,27 +8,45 @@ from pathlib import Path
 import colossalai
 import pytest
 import torch
+import torch.nn as nn
 import torch.multiprocessing as mp
-from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.core import global_context as gpc
 from colossalai.engine.schedule import PipelineSchedule
+from colossalai.context import ParallelMode
 from colossalai.initialize import launch
 from colossalai.utils import free_port, get_dataloader, print_rank_0
 from colossalai.testing import rerun_on_exception
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
+from torchvision.models import resnet18
 
-BATCH_SIZE = 4
-DIR_PATH = osp.dirname(osp.realpath(__file__))
-CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
+BATCH_SIZE = 8
+
+CONFIG = dict(
+    NUM_MICRO_BATCHES=2,
+    parallel=dict(
+        pipeline=dict(size=2),
+        tensor=dict(size=1, mode=None),
+    ),
+)
 
 
 def run_schedule(rank, world_size, port):
-    launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     # build model
-    model = build_pipeline_model_from_cfg(gpc.config.model, 1)
+    model = resnet18(num_classes=10)
+    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+        model = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
+    elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
+
+        class Flatten(nn.Module):
+
+            def forward(self, x):
+                return torch.flatten(x, 1)
+
+        model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)
+
     print_rank_0('model is created')
 
     train_dataset = CIFAR10(root=Path(os.environ['DATA']),
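
An aside on the partition above: the two nn.Sequential stages are exactly torchvision's resnet18 forward pass split after layer2, so chaining them reproduces the unpartitioned model. A single-process sanity check (illustration only, not part of this commit) could read:

import torch
import torch.nn as nn
from torchvision.models import resnet18


class Flatten(nn.Module):

    def forward(self, x):
        return torch.flatten(x, 1)


model = resnet18(num_classes=10).eval()
stage0 = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool, model.layer1, model.layer2)
stage1 = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc)

# CIFAR10-shaped input, matching the dataset used by the test
x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    # the composed stages run the very same modules in the same order,
    # so they must match the full forward pass bit-for-bit
    assert torch.equal(stage1(stage0(x)), model(x))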
@@ -69,7 +87,7 @@ def run_schedule(rank, world_size, port):
 
 @pytest.mark.dist
 @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_pipeline_schedule():
-    world_size = 4
+    world_size = 2
     run_func = partial(run_schedule, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
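
The body of run_schedule between the model construction and the test runner is elided from this view; presumably it builds the CIFAR10 dataloader and drives one step of the pipeline schedule. A rough sketch of such a tail against the 0.x-era ColossalAI API — the initialize return signature and the engine.execute_schedule call are assumptions here, not visible in this diff:

# hypothetical continuation of run_schedule(); the elided part of the test
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# with pipeline size 2 and NUM_MICRO_BATCHES set in the config,
# colossalai.initialize is expected to attach a PipelineSchedule
# to the engine (assumed behaviour of the 0.x initializer)
engine, train_dataloader, *_ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

engine.train()
data_iter = iter(train_dataloader)
engine.zero_grad()
# one forward/backward sweep over the two micro-batches of a single batch;
# execute_schedule is assumed to exist on the 0.x Engine
engine.execute_schedule(data_iter, return_output_label=False)
engine.step()
gpc.destroy()

Whatever the exact tail looks like, the visible change is clear: replacing build_pipeline_model_from_cfg and CONFIG_PATH with an inline CONFIG dict and explicit per-rank stages makes the test self-contained and drops its dependency on the external resnet_config.py.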