Hotfix/Colossalai layers (#92)

* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Authored by アマデウス on 2021-12-29 23:32:10 +08:00; committed by GitHub
parent 0fedef4f3c
commit 01a80cd86d
71 changed files with 1033 additions and 773 deletions
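The third commit-message bullet says batch splitting is now performed when the trainer's schedule loads a batch (trainer.schedule.load_batch); that change is not part of the diff excerpts below. As a rough, hypothetical illustration of what splitting a batch at load time can look like, here is a minimal sketch; the helper names and signatures are illustrative, not ColossalAI's actual trainer code:

import torch

def split_batch(t: torch.Tensor, num_parallel: int, rank: int) -> torch.Tensor:
    # Illustrative only: give each parallel rank its own slice of the
    # global batch along the batch dimension.
    return t.chunk(num_parallel, dim=0)[rank]

def load_batch(data_iter, num_parallel: int, rank: int):
    # Hypothetical load_batch: fetch the next (data, label) pair and split
    # both tensors before the schedule runs forward/backward on them.
    data, label = next(data_iter)
    return split_batch(data, num_parallel, rank), split_batch(label, num_parallel, rank)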


@@ -1,25 +1,23 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import colossalai
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import os.path as osp
from pathlib import Path
import torch.nn as nn
import torch.multiprocessing as mp
from torchvision import transforms
from torch.optim import Adam
from colossalai.core import global_context as gpc
import torch.nn as nn
from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage, get_dataloader
from torchvision.models import resnet18
from colossalai.utils import free_port, get_dataloader, report_memory_usage
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from functools import partial
from torchvision.models import resnet18
# Config
BATCH_SIZE = 128
@@ -38,14 +36,14 @@ CONFIG = dict(
)
-def run_engine(rank, world_size):
+def run_engine(rank, world_size, port):
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
-        port=29910,
+        port=port,
        backend='nccl'
    )
@@ -104,7 +102,7 @@ def run_engine(rank, world_size):
@pytest.mark.dist
def test_engine():
    world_size = 4
-    run_func = partial(run_engine, world_size=world_size)
+    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
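The recurring change in these test files is the port handling: each test previously hard-coded its own rendezvous port (29910 here), which can collide when several test runs share a machine, while the new code picks a free port for every spawn. A minimal sketch of such a helper, assuming the common bind-to-port-0 trick; colossalai.utils.free_port may be implemented differently:

import socket

def free_port() -> int:
    # Ask the OS for an ephemeral port by binding to port 0, then release it.
    # A small race remains between closing the socket and reusing the port,
    # which is usually acceptable for test setup.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))
        return s.getsockname()[1]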


@@ -1,23 +1,20 @@
import colossalai
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import os.path as osp
from pathlib import Path
import torch.nn as nn
import torch.multiprocessing as mp
from torchvision import transforms
from torch.optim import Adam
from colossalai.core import global_context as gpc
import torch.nn as nn
from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage, get_dataloader
from colossalai.initialize import get_default_parser
from torchvision.models import resnet18
from colossalai.utils import free_port, get_dataloader, report_memory_usage
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from functools import partial
from torchvision.models import resnet18
# Config
BATCH_SIZE = 128
@@ -38,14 +35,14 @@ CONFIG = dict(
)
-def run_engine(rank, world_size):
+def run_engine(rank, world_size, port):
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
-        port=29911,
+        port=port,
        backend='nccl'
    )
@@ -104,7 +101,7 @@ def run_engine(rank, world_size):
@pytest.mark.dist
def test_engine():
    world_size = 4
-    run_func = partial(run_engine, world_size=world_size)
+    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
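A note on the spawn pattern used by every test here: torch.multiprocessing.spawn calls its target as fn(i, *args), passing the process index as the first positional argument, which is why run_engine takes rank first while world_size and port are pre-bound with functools.partial. A self-contained sketch of the same pattern; the worker body and the example port are placeholders:

from functools import partial

import torch.multiprocessing as mp

def worker(rank, world_size, port):
    # mp.spawn supplies rank (0 .. nprocs-1); world_size and port were
    # bound in advance via functools.partial.
    print(f'worker {rank}/{world_size} would rendezvous on port {port}')

if __name__ == '__main__':
    world_size = 4
    run_func = partial(worker, world_size=world_size, port=29500)
    mp.spawn(run_func, nprocs=world_size)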


@@ -1,23 +1,19 @@
import colossalai
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import os.path as osp
from pathlib import Path
import torch.nn as nn
import torch.multiprocessing as mp
from torchvision import transforms
from torch.optim import Adam
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.amp import AMP_TYPE
from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage, get_dataloader
from colossalai.initialize import get_default_parser
from torchvision.models import resnet18
from colossalai.utils import free_port, get_dataloader, report_memory_usage
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from functools import partial
from torchvision.models import resnet18
# Config
BATCH_SIZE = 128
@@ -35,14 +31,14 @@ CONFIG = dict(
)
-def run_engine(rank, world_size):
+def run_engine(rank, world_size, port):
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
-        port=29912,
+        port=port,
        backend='nccl'
    )
@@ -101,7 +97,7 @@ def run_engine(rank, world_size):
@pytest.mark.dist
def test_engine():
    world_size = 4
-    run_func = partial(run_engine, world_size=world_size)
+    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
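colossalai.launch is handed the rank, world size, rendezvous host/port and backend explicitly. For orientation only, the rendezvous it has to establish corresponds roughly to a plain torch.distributed.init_process_group call such as the following; this is a rough equivalent, not ColossalAI's internal implementation:

import torch.distributed as dist

def init_dist(rank: int, world_size: int, host: str, port: int, backend: str = 'nccl'):
    # Establish the TCP rendezvous shared by all ranks and create the
    # default process group.
    dist.init_process_group(
        backend=backend,
        init_method=f'tcp://{host}:{port}',
        rank=rank,
        world_size=world_size,
    )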


@@ -1,23 +1,20 @@
import colossalai
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import os.path as osp
from pathlib import Path
import torch.nn as nn
import torch.multiprocessing as mp
from torchvision import transforms
from torch.optim import Adam
from colossalai.core import global_context as gpc
import torch.nn as nn
from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import report_memory_usage, get_dataloader
from colossalai.initialize import get_default_parser
from torchvision.models import resnet18
from colossalai.utils import free_port, get_dataloader, report_memory_usage
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from functools import partial
from torchvision.models import resnet18
# Config
BATCH_SIZE = 128
@@ -36,14 +33,14 @@ CONFIG = dict(
)
-def run_engine(rank, world_size):
+def run_engine(rank, world_size, port):
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
-        port=29913,
+        port=port,
        backend='nccl'
    )
@@ -102,7 +99,7 @@ def run_engine(rank, world_size):
@pytest.mark.dist
def test_engine():
    world_size = 4
-    run_func = partial(run_engine, world_size=world_size)
+    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
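All four tests carry the @pytest.mark.dist marker. It is a custom marker, so recent pytest versions warn about it unless it is registered; a typical registration via a conftest.py hook is sketched below, though the project may register the marker elsewhere (for example in pytest.ini):

# conftest.py (hypothetical location)
def pytest_configure(config):
    # Register the custom marker so 'pytest -m dist' selects these tests
    # without raising PytestUnknownMarkWarning.
    config.addinivalue_line('markers', 'dist: tests that spawn multiple distributed processes')

With the marker registered, the distributed suite can be selected with pytest -m dist or skipped with pytest -m 'not dist'.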