Hotfix/Colossalai layers (#92)

* optimized 1D layer APIs; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked batch splitting, which is now invoked from trainer.schedule.load_batch (see the sketch below)

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Author: アマデウス
Date: 2021-12-29 23:32:10 +08:00
Committed by: GitHub
Parent: 0fedef4f3c
Commit: 01a80cd86d
71 changed files with 1033 additions and 773 deletions
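
The third bullet above is the behavioral change: batch splitting is now performed when the trainer's schedule loads a batch, rather than at scattered call sites. A minimal sketch of what such a split can look like, assuming sharding along the batch dimension (dim 0) across the ranks of a parallel group; `split_batch` below is illustrative only, not ColossalAI's actual implementation:

```python
import torch


def split_batch(batch: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return the contiguous dim-0 shard of `batch` owned by `rank`.

    Hypothetical stand-in for the split performed inside
    trainer.schedule.load_batch; not the library's real code.
    """
    assert batch.size(0) % world_size == 0, 'batch size must divide evenly across ranks'
    return batch.chunk(world_size, dim=0)[rank]
```

Centralizing the split in `load_batch` means every schedule sees an already-sharded batch, instead of each training loop re-implementing the slicing.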


@@ -1,26 +1,23 @@
 # referenced from Megatron and used to testify communication
-import colossalai
 import os
 import os.path as osp
+from functools import partial
+from pathlib import Path
+import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+import model
 from colossalai.builder import build_pipeline_model_from_cfg
 from colossalai.communication import p2p as p2p_communication
 from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
-from colossalai.utils import print_rank_0, get_current_device, get_dataloader
 from colossalai.engine.schedule import PipelineSchedule
-from torchvision.datasets import CIFAR10
+from colossalai.initialize import launch
+from colossalai.utils import free_port, get_dataloader, print_rank_0
 from torchvision import transforms
-from pathlib import Path
-from functools import partial
-import model
+from torchvision.datasets import CIFAR10
 
 BATCH_SIZE = 32
 NUM_MICRO = 8
@@ -30,12 +27,12 @@ DIR_PATH = osp.dirname(osp.realpath(__file__))
 CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
 
-def run_schedule(rank, world_size):
+def run_schedule(rank, world_size, port):
     launch(config=CONFIG_PATH,
            rank=rank,
            world_size=world_size,
            host='localhost',
-           port=29934,
+           port=port,
            backend='nccl')
 
     # build model
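
This hunk replaces the hard-coded port 29934 with a `port` parameter, so concurrent test runs on one machine no longer collide. The value comes from `free_port()`, imported from `colossalai.utils` in the first hunk. A minimal sketch of the usual bind-to-port-0 idiom such a helper relies on; this is an assumption about its internals, not the library's actual code:

```python
import socket


def free_port() -> int:
    """Ask the OS for a TCP port that is currently unused."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))  # port 0 lets the OS choose any free port
        return s.getsockname()[1]  # the port the OS actually assigned
```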
@@ -86,7 +83,7 @@ def run_schedule(rank, world_size):
 @pytest.mark.dist
 def test_pipeline_schedule():
     world_size = 4
-    run_func = partial(run_schedule, world_size=world_size)
+    run_func = partial(run_schedule, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
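
The `partial`/`mp.spawn` pairing works because `mp.spawn` invokes the function with the process index as its first positional argument, so with `world_size` and `port` pre-bound each worker runs `run_schedule(rank, world_size, port)`. Note that `free_port()` is evaluated once in the parent process, so all four ranks rendezvous on the same port. A self-contained sketch of the pattern (the worker body and the port value are placeholders):

```python
from functools import partial

import torch.multiprocessing as mp


def worker(rank, world_size, port):
    # mp.spawn supplies `rank`; world_size and port were bound via partial.
    print(f'rank {rank}/{world_size} rendezvous on port {port}')


if __name__ == '__main__':
    # Each of the 4 processes receives its own rank plus the shared arguments.
    mp.spawn(partial(worker, world_size=4, port=29500), nprocs=4)
```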