[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions
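The spawn helper that every file below switches to is not itself part of this diff, but the replaced lines show exactly what it stands in for: bind world_size, grab a free port, and fan the worker out with torch.multiprocessing. A minimal sketch of such a wrapper is given here for orientation; it assumes the (rank, world_size, port) worker signature used throughout these tests, and the real colossalai.testing.spawn may differ in parameter names and extras.

# Illustrative sketch only: the real helper lives in colossalai.testing.
# It reproduces the boilerplate the old tests repeated by hand.
from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port  # same utility the old tests imported


def spawn(func, world_size=1, **kwargs):
    # func is invoked as func(rank, world_size, port, **kwargs) in each process
    run_func = partial(func, world_size=world_size, port=free_port(), **kwargs)
    mp.spawn(run_func, nprocs=world_size)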

View File

@@ -1,14 +1,11 @@
import colossalai
import torch
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor import ColoTensorSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
class Conv1D(nn.Module):
@@ -69,8 +66,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_addmm_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':

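The same launch-pattern change repeats in every file that follows: drop functools.partial, torch.multiprocessing and free_port, import spawn from colossalai.testing, and collapse the two-line launch into a single call. Condensed into one self-contained template (the check and test names here are placeholders, not code from this commit):

# Minimal post-refactor test shape, assembled from the hunks in this commit.
# check_something / test_something are placeholder names for illustration.
import pytest
import torch

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_something():
    # stand-in for the real tensor assertions in each test file
    assert torch.allclose(torch.ones(2), torch.ones(2))


def run_dist(rank, world_size, port):
    # each worker process joins the distributed group, then runs the checks
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_something()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_something(world_size):
    # replaces: run_func = partial(run_dist, ...); mp.spawn(run_func, nprocs=world_size)
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_something(4)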
View File

@@ -1,14 +1,11 @@
from torch.nn import functional as F
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from torch.nn import functional as F
import colossalai
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func):
@@ -39,8 +36,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_bag_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':

View File

@@ -1,14 +1,11 @@
from torch.nn import functional as F
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
from torch.nn import functional as F
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func, pg: ProcessGroup):
@@ -40,8 +37,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':

View File

@@ -1,14 +1,11 @@
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func, split_bias):
@@ -44,8 +41,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_linear_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':

View File

@@ -1,52 +1,48 @@
import torch
import pytest
import colossalai
import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
from colossalai.utils import get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern
def check_cross_entropy():
input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
with torch.no_grad():
input_ct.copy_(input_t)
target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
output = F.cross_entropy(input_t, target)
output_colo = F.cross_entropy(input_shard, target)
assert torch.allclose(output_colo, output)
output.backward()
output_colo.backward()
assert torch.allclose(input_t.grad, input_ct.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_cross_entropy()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_loss_func(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_loss_func(1)
import pytest
import torch
import torch.nn.functional as F
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
def check_cross_entropy():
input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
with torch.no_grad():
input_ct.copy_(input_t)
target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
output = F.cross_entropy(input_t, target)
output_colo = F.cross_entropy(input_shard, target)
assert torch.allclose(output_colo, output)
output.backward()
output_colo.backward()
assert torch.allclose(input_t.grad, input_ct.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_cross_entropy()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_loss_func(world_size):
spawn(run_dist, world_size)
if __name__ == '__main__':
test_loss_func(1)

View File

@@ -1,14 +1,12 @@
import torch
import pytest
import colossalai
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
from colossalai.utils import get_current_device
from torch.nn import Parameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
def _run_layer_norm():
@@ -66,8 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_element_wise_ops(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
def run_dist2(rank, world_size, port):
@@ -79,8 +76,7 @@ def run_dist2(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1])
@rerun_if_address_is_in_use()
def test_ln(world_size):
run_func = partial(run_dist2, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist2, world_size)
def check_all():

View File

@@ -1,100 +1,97 @@
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec
from colossalai.tensor.distspec import DistPlacementPattern
from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print
def exam_view_core(pg):
# the case of replicated ColoTensors
x = torch.randn(4, 4).cuda()
x_colo = ColoTensor(x, ColoTensorSpec(pg))
y = x.view(2, -1, 2)
y_colo = x_colo.view(2, -1, 2)
assert torch.all(y == y_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
# the perfect case of col-sliced ColoTensors
split_param_col_tp1d(x_colo, pg)
z = x.view(torch.Size((2, 1, 2, -1)))
z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
if dist.get_rank() == 0:
z = z[:, :, :, 0:2]
else:
z = z[:, :, :, 2:]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the perfect case of row-sliced ColoTensors
split_param_row_tp1d(x_colo, pg)
z = x.view(torch.Size((-1, 2, 2)))
z_colo = x_colo.view(torch.Size((-1, 2, 2)))
if dist.get_rank() == 0:
z = z[0:2, :, :]
else:
z = z[2:, :, :]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the normal case of row-sliced ColoTensors
z = x.view(-1, 2, 2, 2)
z_colo = x_colo.view(-1, 2, 2, 2)
assert torch.all(z == z_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
def exam_view_autograd(pg):
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
with torch.no_grad():
y.copy_(x)
y = ColoTensor(y, ColoTensorSpec(pg))
y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
xx = x.view(2, 2, -1)
yy_slice = y_slice.view(2, 2, -1)
yy = yy_slice.to_replicate()
grad = torch.randn(2, 2, 4, device=get_current_device())
xx.backward(grad)
yy.backward(grad)
assert torch.all(x.grad == y.grad)
def exam_view_errors(pg):
x = torch.randn(8, 2, device=get_current_device())
x = ColoTensor(x, ColoTensorSpec(pg))
split_param_row_tp1d(x, pg)
x.view('a', 'b', 'c')
x.view(8, -1)
x.view([-2, -2, -2])
x.view((-1, -1, -1))
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
exam_view_core(pg)
exam_view_autograd(pg)
# exam_view_errors(pg)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_view(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_view(2)
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.tensor.distspec import DistPlacementPattern
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d
def exam_view_core(pg):
# the case of replicated ColoTensors
x = torch.randn(4, 4).cuda()
x_colo = ColoTensor(x, ColoTensorSpec(pg))
y = x.view(2, -1, 2)
y_colo = x_colo.view(2, -1, 2)
assert torch.all(y == y_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
# the perfect case of col-sliced ColoTensors
split_param_col_tp1d(x_colo, pg)
z = x.view(torch.Size((2, 1, 2, -1)))
z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
if dist.get_rank() == 0:
z = z[:, :, :, 0:2]
else:
z = z[:, :, :, 2:]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the perfect case of row-sliced ColoTensors
split_param_row_tp1d(x_colo, pg)
z = x.view(torch.Size((-1, 2, 2)))
z_colo = x_colo.view(torch.Size((-1, 2, 2)))
if dist.get_rank() == 0:
z = z[0:2, :, :]
else:
z = z[2:, :, :]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the normal case of row-sliced ColoTensors
z = x.view(-1, 2, 2, 2)
z_colo = x_colo.view(-1, 2, 2, 2)
assert torch.all(z == z_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
def exam_view_autograd(pg):
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
with torch.no_grad():
y.copy_(x)
y = ColoTensor(y, ColoTensorSpec(pg))
y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
xx = x.view(2, 2, -1)
yy_slice = y_slice.view(2, 2, -1)
yy = yy_slice.to_replicate()
grad = torch.randn(2, 2, 4, device=get_current_device())
xx.backward(grad)
yy.backward(grad)
assert torch.all(x.grad == y.grad)
def exam_view_errors(pg):
x = torch.randn(8, 2, device=get_current_device())
x = ColoTensor(x, ColoTensorSpec(pg))
split_param_row_tp1d(x, pg)
x.view('a', 'b', 'c')
x.view(8, -1)
x.view([-2, -2, -2])
x.view((-1, -1, -1))
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
exam_view_core(pg)
exam_view_autograd(pg)
# exam_view_errors(pg)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_view(world_size):
spawn(run_dist, world_size)
if __name__ == '__main__':
test_view(2)