Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
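The refactor replaces the per-test boilerplate of `partial` + `free_port()` + `mp.spawn(...)` with a single `spawn` helper from `colossalai.testing`. As a rough mental model, here is a minimal sketch of such a helper; the signature `spawn(func, nprocs, **kwargs)` and the worker signature `func(rank, world_size, port, ...)` are inferred from the call sites in this diff, and the real implementation in `colossalai.testing` may differ.

# Minimal sketch of a spawn-style test helper (an assumption, not the
# actual colossalai.testing implementation).
import socket
from functools import partial

import torch.multiprocessing as mp


def _free_port() -> int:
    # Ask the OS for an unused TCP port (hypothetical stand-in for the
    # free_port utility the old tests imported).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))
        return s.getsockname()[1]


def spawn(func, nprocs=1, **kwargs):
    # mp.spawn passes the rank as the first positional argument, so each
    # worker is called as func(rank, world_size=nprocs, port=..., **kwargs),
    # matching the run_dist(rank, world_size, port, ...) functions below.
    wrapped = partial(func, world_size=nprocs, port=_free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)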
@@ -1,13 +1,12 @@
 import math
+
+import pytest
 import torch
 import torch.distributed as dist
-import pytest
+
 import colossalai
-import torch.multiprocessing as mp
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
-from functools import partial
+from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def run():
@@ -58,8 +57,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_dist_spec_mgr(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)


 if __name__ == '__main__':
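Every migrated test keeps the `rerun_if_address_is_in_use()` decorator: it retries a test whose rendezvous port collides with one still held by an earlier run. A hypothetical retry decorator in that spirit could look like the sketch below (shown only for illustration; ColossalAI's actual implementation in `colossalai.testing` differs).

# Hypothetical retry-on-port-collision decorator; not the real
# colossalai.testing.rerun_if_address_is_in_use.
import functools


def rerun_if_address_is_in_use(max_retries: int = 3):
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            for _ in range(max_retries - 1):
                try:
                    return test_fn(*args, **kwargs)
                except RuntimeError as e:
                    # Process-group init raises when the TCP port is taken;
                    # any other error propagates immediately.
                    if 'address already in use' not in str(e).lower():
                        raise
            return test_fn(*args, **kwargs)
        return wrapper
    return decorator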
@@ -1,17 +1,11 @@
-import torch
 import pytest
-from colossalai.tensor import ColoTensor
+import torch
 from numpy import allclose

 import colossalai
-from colossalai.utils import free_port
-from colossalai.tensor import ColoTensorSpec
 from colossalai.core import global_context as gpc
-import torch.multiprocessing as mp
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
-from functools import partial
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def _run_tensor_indexing():
@@ -152,8 +146,7 @@ def run_dist_tests(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
 def test_dist_cases(world_size):
-    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_tests, world_size)


 if __name__ == '__main__':
@@ -1,15 +1,11 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP

 import colossalai
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -145,8 +141,7 @@ def run_dist(rank, world_size, port, use_ddp):
 @pytest.mark.parametrize('use_ddp', [False, True])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size, use_ddp):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, use_ddp=use_ddp)


 if __name__ == '__main__':
@@ -1,15 +1,11 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp

 import colossalai
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor import ColoTensor, ProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -313,8 +309,7 @@ def run_model_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_model(world_size):
-    run_func = partial(run_model_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_model_dist, world_size)


 def run_pretrain_load_dist(rank, world_size, port):
@@ -329,8 +324,7 @@ def run_pretrain_load_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_pretrain_load(world_size):
-    run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_pretrain_load_dist, world_size)


 if __name__ == '__main__':
@@ -1,9 +1,7 @@
 from copy import deepcopy
-from functools import partial

 import pytest
 import torch
-import torch.multiprocessing as mp

 import colossalai
 from colossalai.nn.parallel.layers import check_colo_module, init_colo_module
@@ -17,8 +15,7 @@ from colossalai.tensor import (
     ShardSpec,
     distspec,
 )
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -207,8 +204,7 @@ def run_dist_check(rank, world_size, port):
 @pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_linear_1d(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)


 @pytest.mark.dist
@@ -216,8 +212,7 @@ def test_module_linear_1d(world_size):
 @pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_model(world_size):
-    run_func = partial(run_dist_model, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_model, world_size)


 @pytest.mark.dist
@@ -225,8 +220,7 @@ def test_module_model(world_size):
 @pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_check(world_size):
-    run_func = partial(run_dist_check, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_check, world_size)


 if __name__ == '__main__':
@@ -1,47 +1,41 @@
-import torch
-import pytest
-from functools import partial
-
-import torch.multiprocessing as mp
-import torch.distributed as dist
-
-import colossalai
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils import free_port
-from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec
-from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
-from tests.test_tensor.common_utils import tensor_shard_equal
-
-
-def run_dist(rank, world_size, port, dp_degree, tp_degree):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
-    x = torch.randn(4, 4)
-    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
-    spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
-    param.set_tensor_spec(*spec)
-
-    gather_tensor(param)
-    if dist.get_rank() == 0:
-        assert torch.all(x == param)
-    else:
-        assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
-    dist.barrier()
-
-    scatter_tensor(param, spec[0])
-    assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
-    assert param.requires_grad is True
-    dist.barrier()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [4])
-@rerun_if_address_is_in_use()
-def test_checkpoint(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2)
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_checkpoint(world_size=4)
+import pytest
+import torch
+import torch.distributed as dist
+
+import colossalai
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
+from tests.test_tensor.common_utils import tensor_shard_equal
+
+
+def run_dist(rank, world_size, port, dp_degree, tp_degree):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
+    x = torch.randn(4, 4)
+    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
+    spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
+    param.set_tensor_spec(*spec)
+
+    gather_tensor(param)
+    if dist.get_rank() == 0:
+        assert torch.all(x == param)
+    else:
+        assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
+    dist.barrier()
+
+    scatter_tensor(param, spec[0])
+    assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
+    assert param.requires_grad is True
+    dist.barrier()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [4])
+@rerun_if_address_is_in_use()
+def test_checkpoint(world_size):
+    spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2)
+
+
+if __name__ == '__main__':
+    test_checkpoint(world_size=4)
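Note how the checkpoint test forwards its extra parameters: under the sketch near the top, keyword arguments after `world_size` ride along to every worker, so `run_dist(rank, world_size, port, dp_degree, tp_degree)` receives them without any `partial` plumbing.

# Extra keyword arguments are forwarded verbatim to the worker function.
spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2)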
@@ -1,10 +1,5 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from torch.distributed import ReduceOp

 from colossalai.core import global_context as gpc
 from colossalai.device.device_mesh import DeviceMesh
@@ -12,8 +7,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_all_gather(device_mesh, rank):
@@ -218,8 +212,7 @@ def check_comm(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_comm_spec():
     world_size = 4
-    run_func = partial(check_comm, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_comm, world_size)


 if __name__ == '__main__':
@@ -1,8 +1,5 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp

 import colossalai
 from colossalai.tensor import (
@@ -14,8 +11,7 @@ from colossalai.tensor import (
     ReplicaSpec,
     ShardSpec,
 )
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -61,8 +57,7 @@ def run_colo_init_context(rank: int, world_size: int, port: int):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_colo_init_context(world_size):
-    run_func = partial(run_colo_init_context, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_colo_init_context, world_size)


 if __name__ == '__main__':
@@ -1,9 +1,6 @@
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torch.distributed import ReduceOp

 from colossalai.core import global_context as gpc
@@ -12,8 +9,7 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
 from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_all_gather(process_groups_dict, rank):
@@ -182,8 +178,7 @@ def check_comm(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_comm_spec():
     world_size = 4
-    run_func = partial(check_comm, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_comm, world_size)


 if __name__ == '__main__':
@@ -1,7 +1,4 @@
-from functools import partial
-
 import torch
-import torch.multiprocessing as mp

 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
@@ -9,7 +6,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 class TestModel(torch.nn.Module):
@@ -92,10 +89,10 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')


+@rerun_if_address_is_in_use()
 def test_dtensor():
     world_size = 4
-    run_func = partial(check_dtensor, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_dtensor, world_size)


 if __name__ == '__main__':
@@ -1,9 +1,7 @@
 import math
-from functools import partial

 import pytest
 import torch
-import torch.multiprocessing as mp

 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
@@ -12,8 +10,7 @@ from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
 from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 entire_shape = torch.Size((64, 32, 16))
 layout_converter = LayoutConverter()
@@ -192,14 +189,9 @@ def check_layout_converting_apply(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_layout_converter():
     world_size = 4
-    run_func = partial(check_one_step_transform, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-    run_func = partial(check_layout_converting, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-    run_func = partial(check_layout_converting_apply, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_one_step_transform, world_size)
+    spawn(check_layout_converting, world_size)
+    spawn(check_layout_converting_apply, world_size)


 if __name__ == '__main__':
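test_layout_converter above launches three independent 4-process jobs back to back inside one pytest case. Assuming the helper allocates a fresh free port per call (as the sketch near the top of this page does), the sequential jobs do not collide on a rendezvous address; the rerun_if_address_is_in_use() decorator still covers the case where the OS releases a port late between runs.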
@@ -1,8 +1,5 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp

 from colossalai.core import global_context as gpc
 from colossalai.device.device_mesh import DeviceMesh
@@ -11,7 +8,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.tensor.utils import mix_gather_simulator
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_mix_gather_S0S1(device_mesh, rank):
@@ -323,10 +320,10 @@ def check_comm(rank, world_size, port):


 @pytest.mark.skip(reason="Skip because the check functions assume 8 GPUS but CI only have 4 GPUs")
+@rerun_if_address_is_in_use()
 def test_mix_gather():
     world_size = 8
-    run_func = partial(check_comm, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_comm, world_size)


 if __name__ == '__main__':
@@ -1,9 +1,10 @@
-from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
-import torch
 import pytest
+import torch
+from common_utils import tensor_equal

 import colossalai
-from colossalai.utils import free_port
+from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.testing import free_port


 @pytest.mark.skip
@@ -1,16 +1,12 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp

 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_apply(rank, world_size, port):
@@ -73,8 +69,7 @@ def check_apply(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_apply():
     world_size = 4
-    run_func = partial(check_apply, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_apply, world_size)


 if __name__ == '__main__':
@@ -1,8 +1,5 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn.functional as F

 import colossalai
@@ -10,8 +7,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.nn._ops._utils import gather_forward_split_backward
 from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
 from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def run_dist(rank, world_size, port):
@@ -229,8 +225,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [4])
 @rerun_if_address_is_in_use()
 def test_sharded_mlp(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)


 if __name__ == '__main__':
@@ -1,15 +1,11 @@
-from functools import partial
-
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP

 import colossalai
 from colossalai.amp import convert_to_apex_amp
 from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP
 from colossalai.zero.gemini import search_chunk_configuration
@@ -140,8 +136,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)


 if __name__ == '__main__':
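Taken together, a migrated test reduces to the pattern below: an illustrative composite assembled from the call sites in this commit (the names run_dist and test_example are placeholders, not a file from the repo).

import pytest
import torch.distributed as dist

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # Every worker sets up its own process group, exactly as the tests above do.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    assert dist.get_world_size() == world_size


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_example(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_example(4)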