[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions
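Every file in this change applies the same mechanical rewrite. The old per-test boilerplate

    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

collapses into a single call to the new helper exported by colossalai.testing:

    spawn(run_dist, world_size)

Worker-specific arguments are forwarded as keywords, e.g. spawn(run_dist, world_size, use_ddp=use_ddp) or spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2).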


@@ -1,13 +1,12 @@
 import math
+import pytest
 import torch
 import torch.distributed as dist
-import pytest
 import colossalai
-import torch.multiprocessing as mp
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
-from functools import partial
+from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 def run():
@@ -58,8 +57,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_dist_spec_mgr(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)

 if __name__ == '__main__':
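
The spawn helper itself is not part of this diff. A minimal sketch of the shape it presumably takes (a hypothetical reconstruction, not the actual colossalai.testing implementation) reuses the same free_port utility the old call sites invoked:

    from functools import partial

    import torch.multiprocessing as mp

    from colossalai.utils import free_port


    def spawn(func, nprocs=1, **kwargs):
        # Bind the world size, a freshly allocated port, and any extra
        # keyword arguments; mp.spawn then supplies the rank as the first
        # positional argument, matching func(rank, world_size, port, ...).
        wrapped = partial(func, world_size=nprocs, port=free_port(), **kwargs)
        mp.spawn(wrapped, nprocs=nprocs)

Because the port is allocated inside the helper, the test files no longer need functools.partial or free_port at all, which is exactly what the import hunks below remove.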


@@ -1,17 +1,11 @@
-import torch
 import pytest
-from colossalai.tensor import ColoTensor
+import torch
 from numpy import allclose
 import colossalai
-from colossalai.utils import free_port
-from colossalai.tensor import ColoTensorSpec
 from colossalai.core import global_context as gpc
-import torch.multiprocessing as mp
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
-from functools import partial
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 def _run_tensor_indexing():
@@ -152,8 +146,7 @@ def run_dist_tests(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
 def test_dist_cases(world_size):
-    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_tests, world_size)

 if __name__ == '__main__':


@@ -1,15 +1,11 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -145,8 +141,7 @@ def run_dist(rank, world_size, port, use_ddp):
 @pytest.mark.parametrize('use_ddp', [False, True])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size, use_ddp):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, use_ddp=use_ddp)

 if __name__ == '__main__':


@@ -1,15 +1,11 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import colossalai
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor import ColoTensor, ProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -313,8 +309,7 @@ def run_model_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_model(world_size):
-    run_func = partial(run_model_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_model_dist, world_size)

 def run_pretrain_load_dist(rank, world_size, port):
@@ -329,8 +324,7 @@ def run_pretrain_load_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_pretrain_load(world_size):
-    run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_pretrain_load_dist, world_size)

 if __name__ == '__main__':


@@ -1,9 +1,7 @@
 from copy import deepcopy
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import colossalai
 from colossalai.nn.parallel.layers import check_colo_module, init_colo_module
@@ -17,8 +15,7 @@ from colossalai.tensor import (
     ShardSpec,
     distspec,
 )
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -207,8 +204,7 @@ def run_dist_check(rank, world_size, port):
 @pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_linear_1d(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)

 @pytest.mark.dist
@@ -216,8 +212,7 @@ def test_module_linear_1d(world_size):
 @pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_model(world_size):
-    run_func = partial(run_dist_model, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_model, world_size)

 @pytest.mark.dist
@@ -225,8 +220,7 @@ def test_module_model(world_size):
 @pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_check(world_size):
-    run_func = partial(run_dist_check, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_check, world_size)

 if __name__ == '__main__':


@@ -1,47 +1,41 @@
-import torch
-import pytest
-from functools import partial
-import torch.multiprocessing as mp
-import torch.distributed as dist
-import colossalai
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils import free_port
-from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec
+import pytest
+import torch
+import torch.distributed as dist
+import colossalai
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
 from tests.test_tensor.common_utils import tensor_shard_equal

 def run_dist(rank, world_size, port, dp_degree, tp_degree):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
     x = torch.randn(4, 4)
     param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
     spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
     param.set_tensor_spec(*spec)
     gather_tensor(param)
     if dist.get_rank() == 0:
         assert torch.all(x == param)
     else:
         assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
     dist.barrier()
     scatter_tensor(param, spec[0])
     assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
     assert param.requires_grad is True
     dist.barrier()

 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [4])
 @rerun_if_address_is_in_use()
 def test_checkpoint(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2)

 if __name__ == '__main__':
     test_checkpoint(world_size=4)


@@ -1,10 +1,5 @@
-from functools import partial
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torch.distributed import ReduceOp
 from colossalai.core import global_context as gpc
 from colossalai.device.device_mesh import DeviceMesh
@@ -12,8 +7,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 def check_all_gather(device_mesh, rank):
@@ -218,8 +212,7 @@ def check_comm(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_comm_spec():
     world_size = 4
-    run_func = partial(check_comm, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_comm, world_size)

 if __name__ == '__main__':


@@ -1,8 +1,5 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import colossalai
 from colossalai.tensor import (
@@ -14,8 +11,7 @@ from colossalai.tensor import (
     ReplicaSpec,
     ShardSpec,
 )
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -61,8 +57,7 @@ def run_colo_init_context(rank: int, world_size: int, port: int):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_colo_init_context(world_size):
-    run_func = partial(run_colo_init_context, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_colo_init_context, world_size)

 if __name__ == '__main__':


@@ -1,9 +1,6 @@
-from functools import partial
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torch.distributed import ReduceOp
 from colossalai.core import global_context as gpc
@@ -12,8 +9,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
 from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 def check_all_gather(process_groups_dict, rank):
@@ -182,8 +178,7 @@ def check_comm(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_comm_spec():
     world_size = 4
-    run_func = partial(check_comm, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_comm, world_size)

 if __name__ == '__main__':


@@ -1,7 +1,4 @@
-from functools import partial
 import torch
-import torch.multiprocessing as mp
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
@@ -9,7 +6,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 class TestModel(torch.nn.Module):
@@ -92,10 +89,10 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')

+@rerun_if_address_is_in_use()
 def test_dtensor():
     world_size = 4
-    run_func = partial(check_dtensor, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_dtensor, world_size)

 if __name__ == '__main__':


@@ -1,9 +1,7 @@
 import math
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
@@ -12,8 +10,7 @@ from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
 from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 entire_shape = torch.Size((64, 32, 16))
 layout_converter = LayoutConverter()
@@ -192,14 +189,9 @@ def check_layout_converting_apply(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_layout_converter():
     world_size = 4
-    run_func = partial(check_one_step_transform, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-    run_func = partial(check_layout_converting, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-    run_func = partial(check_layout_converting_apply, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_one_step_transform, world_size)
+    spawn(check_layout_converting, world_size)
+    spawn(check_layout_converting_apply, world_size)

 if __name__ == '__main__':


@@ -1,8 +1,5 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
 from colossalai.device.device_mesh import DeviceMesh
@@ -11,7 +8,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.tensor.utils import mix_gather_simulator
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 def check_mix_gather_S0S1(device_mesh, rank):
@@ -323,10 +320,10 @@ def check_comm(rank, world_size, port):
 @pytest.mark.skip(reason="Skip because the check functions assume 8 GPUS but CI only have 4 GPUs")
+@rerun_if_address_is_in_use()
 def test_mix_gather():
     world_size = 8
-    run_func = partial(check_comm, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_comm, world_size)

 if __name__ == '__main__':


@@ -1,9 +1,10 @@
-from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
-import torch
 import pytest
+import torch
 from common_utils import tensor_equal
 import colossalai
-from colossalai.utils import free_port
+from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.testing import free_port

 @pytest.mark.skip


@@ -1,16 +1,12 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 def check_apply(rank, world_size, port):
@@ -73,8 +69,7 @@ def check_apply(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_apply():
     world_size = 4
-    run_func = partial(check_apply, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_apply, world_size)

 if __name__ == '__main__':


@@ -1,8 +1,5 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn.functional as F
 import colossalai
@@ -10,8 +7,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.nn._ops._utils import gather_forward_split_backward
 from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
 from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 def run_dist(rank, world_size, port):
@@ -229,8 +225,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [4])
 @rerun_if_address_is_in_use()
 def test_sharded_mlp(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)

 if __name__ == '__main__':


@@ -1,15 +1,11 @@
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.amp import convert_to_apex_amp
 from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP
 from colossalai.zero.gemini import search_chunk_configuration
@@ -140,8 +136,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)

 if __name__ == '__main__':