[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions

View File

@@ -1,21 +1,26 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication import (recv_backward, recv_forward, recv_obj_meta, send_backward,
send_backward_recv_forward, send_forward, send_forward_recv_backward,
send_obj_meta)
from colossalai.communication import (
recv_backward,
recv_forward,
recv_obj_meta,
send_backward,
send_backward_recv_forward,
send_forward,
send_forward_recv_backward,
send_obj_meta,
)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
BATCH_SIZE = 4
SEQ_LENGTH = 2
@@ -93,11 +98,10 @@ def run_check(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_p2p():
world_size = 4
run_func = partial(run_check, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_check, world_size)
if __name__ == '__main__':

View File

@@ -1,34 +1,26 @@
# referenced from Megatron and used to testify communication
import os
import os.path as osp
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.initialize import launch
from colossalai.utils import free_port, get_dataloader, print_rank_0
from colossalai.testing import rerun_on_exception
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_dataloader, print_rank_0
BATCH_SIZE = 8
CONFIG=dict(
NUM_MICRO_BATCHES=2,
parallel = dict(
pipeline=dict(size=2),
tensor=dict(size=1, mode=None)
)
)
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None)))
def run_schedule(rank, world_size, port):
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -85,11 +77,10 @@ def run_schedule(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_pipeline_schedule():
world_size = 2
run_func = partial(run_schedule, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_schedule, world_size)
if __name__ == '__main__':

View File

@@ -1,15 +1,13 @@
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.amp.amp_type import AMP_TYPE
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer, free_port
from colossalai.utils import MultiTimer
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize, rerun_if_address_is_in_use
BATCH_SIZE = 4
IMG_SIZE = 32
@@ -54,8 +52,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_trainer_no_pipeline():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':

View File

@@ -1,23 +1,21 @@
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import PipelineSchedule
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer, free_port, get_dataloader
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from colossalai.testing import rerun_if_address_is_in_use
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer, get_dataloader
BATCH_SIZE = 4
IMG_SIZE = 32
@@ -91,8 +89,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_trainer_with_pipeline():
world_size = 4
run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_trainer_with_pipeline, world_size)
if __name__ == '__main__':