mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 11:08:50 +00:00
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator * polish code * polish code * polish code * polish code * polish code * polish code
This commit is contained in:
@@ -1,14 +1,11 @@
|
||||
import colossalai
|
||||
import torch
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.tensor import ColoTensor, ProcessGroup
|
||||
from colossalai.tensor import ColoTensorSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from functools import partial
|
||||
from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
@@ -69,8 +66,7 @@ def run_dist(rank, world_size, port):
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_addmm_1d(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -1,14 +1,11 @@
|
||||
from torch.nn import functional as F
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from torch.nn import functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
|
||||
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
@@ -39,8 +36,7 @@ def run_dist(rank, world_size, port):
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_embedding_bag_1d(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -1,14 +1,11 @@
|
||||
from torch.nn import functional as F
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
|
||||
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
|
||||
from torch.nn import functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, pg: ProcessGroup):
|
||||
@@ -40,8 +37,7 @@ def run_dist(rank, world_size, port):
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_embedding_1d(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -1,14 +1,11 @@
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn.functional as F
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
|
||||
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func, split_bias):
|
||||
@@ -44,8 +41,7 @@ def run_dist(rank, world_size, port):
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_1d(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -1,52 +1,48 @@
|
||||
import torch
|
||||
import pytest
|
||||
import colossalai
|
||||
import torch.nn.functional as F
|
||||
import torch.multiprocessing as mp
|
||||
from functools import partial
|
||||
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern
|
||||
|
||||
|
||||
def check_cross_entropy():
|
||||
input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
|
||||
input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
|
||||
with torch.no_grad():
|
||||
input_ct.copy_(input_t)
|
||||
|
||||
target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
|
||||
input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
|
||||
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
output = F.cross_entropy(input_t, target)
|
||||
output_colo = F.cross_entropy(input_shard, target)
|
||||
assert torch.allclose(output_colo, output)
|
||||
|
||||
output.backward()
|
||||
output_colo.backward()
|
||||
|
||||
assert torch.allclose(input_t.grad, input_ct.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_cross_entropy()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_loss_func(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_loss_func(1)
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def check_cross_entropy():
|
||||
input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
|
||||
input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
|
||||
with torch.no_grad():
|
||||
input_ct.copy_(input_t)
|
||||
|
||||
target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
|
||||
input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
|
||||
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
output = F.cross_entropy(input_t, target)
|
||||
output_colo = F.cross_entropy(input_shard, target)
|
||||
assert torch.allclose(output_colo, output)
|
||||
|
||||
output.backward()
|
||||
output_colo.backward()
|
||||
|
||||
assert torch.allclose(input_t.grad, input_ct.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_cross_entropy()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_loss_func(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_loss_func(1)
|
||||
|
@@ -1,14 +1,12 @@
|
||||
import torch
|
||||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.multiprocessing as mp
|
||||
from functools import partial
|
||||
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from torch.nn import Parameter
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def _run_layer_norm():
|
||||
@@ -66,8 +64,7 @@ def run_dist(rank, world_size, port):
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_element_wise_ops(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
def run_dist2(rank, world_size, port):
|
||||
@@ -79,8 +76,7 @@ def run_dist2(rank, world_size, port):
|
||||
@pytest.mark.parametrize('world_size', [1])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ln(world_size):
|
||||
run_func = partial(run_dist2, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(run_dist2, world_size)
|
||||
|
||||
|
||||
def check_all():
|
||||
|
@@ -1,100 +1,97 @@
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec
|
||||
from colossalai.tensor.distspec import DistPlacementPattern
|
||||
from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print
|
||||
|
||||
|
||||
def exam_view_core(pg):
|
||||
# the case of replicated ColoTensors
|
||||
x = torch.randn(4, 4).cuda()
|
||||
x_colo = ColoTensor(x, ColoTensorSpec(pg))
|
||||
|
||||
y = x.view(2, -1, 2)
|
||||
y_colo = x_colo.view(2, -1, 2)
|
||||
|
||||
assert torch.all(y == y_colo)
|
||||
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
|
||||
# the perfect case of col-sliced ColoTensors
|
||||
split_param_col_tp1d(x_colo, pg)
|
||||
|
||||
z = x.view(torch.Size((2, 1, 2, -1)))
|
||||
z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
|
||||
if dist.get_rank() == 0:
|
||||
z = z[:, :, :, 0:2]
|
||||
else:
|
||||
z = z[:, :, :, 2:]
|
||||
assert torch.all(z == z_colo)
|
||||
assert z_colo.dist_spec == x_colo.dist_spec
|
||||
# the perfect case of row-sliced ColoTensors
|
||||
split_param_row_tp1d(x_colo, pg)
|
||||
|
||||
z = x.view(torch.Size((-1, 2, 2)))
|
||||
z_colo = x_colo.view(torch.Size((-1, 2, 2)))
|
||||
if dist.get_rank() == 0:
|
||||
z = z[0:2, :, :]
|
||||
else:
|
||||
z = z[2:, :, :]
|
||||
assert torch.all(z == z_colo)
|
||||
assert z_colo.dist_spec == x_colo.dist_spec
|
||||
# the normal case of row-sliced ColoTensors
|
||||
z = x.view(-1, 2, 2, 2)
|
||||
z_colo = x_colo.view(-1, 2, 2, 2)
|
||||
assert torch.all(z == z_colo)
|
||||
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
|
||||
|
||||
|
||||
def exam_view_autograd(pg):
|
||||
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
|
||||
y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
|
||||
with torch.no_grad():
|
||||
y.copy_(x)
|
||||
y = ColoTensor(y, ColoTensorSpec(pg))
|
||||
y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
|
||||
|
||||
xx = x.view(2, 2, -1)
|
||||
yy_slice = y_slice.view(2, 2, -1)
|
||||
yy = yy_slice.to_replicate()
|
||||
grad = torch.randn(2, 2, 4, device=get_current_device())
|
||||
|
||||
xx.backward(grad)
|
||||
yy.backward(grad)
|
||||
assert torch.all(x.grad == y.grad)
|
||||
|
||||
|
||||
def exam_view_errors(pg):
|
||||
x = torch.randn(8, 2, device=get_current_device())
|
||||
x = ColoTensor(x, ColoTensorSpec(pg))
|
||||
split_param_row_tp1d(x, pg)
|
||||
|
||||
x.view('a', 'b', 'c')
|
||||
x.view(8, -1)
|
||||
x.view([-2, -2, -2])
|
||||
x.view((-1, -1, -1))
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
|
||||
exam_view_core(pg)
|
||||
exam_view_autograd(pg)
|
||||
# exam_view_errors(pg)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_view(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_view(2)
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.tensor.distspec import DistPlacementPattern
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d
|
||||
|
||||
|
||||
def exam_view_core(pg):
|
||||
# the case of replicated ColoTensors
|
||||
x = torch.randn(4, 4).cuda()
|
||||
x_colo = ColoTensor(x, ColoTensorSpec(pg))
|
||||
|
||||
y = x.view(2, -1, 2)
|
||||
y_colo = x_colo.view(2, -1, 2)
|
||||
|
||||
assert torch.all(y == y_colo)
|
||||
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
|
||||
# the perfect case of col-sliced ColoTensors
|
||||
split_param_col_tp1d(x_colo, pg)
|
||||
|
||||
z = x.view(torch.Size((2, 1, 2, -1)))
|
||||
z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
|
||||
if dist.get_rank() == 0:
|
||||
z = z[:, :, :, 0:2]
|
||||
else:
|
||||
z = z[:, :, :, 2:]
|
||||
assert torch.all(z == z_colo)
|
||||
assert z_colo.dist_spec == x_colo.dist_spec
|
||||
# the perfect case of row-sliced ColoTensors
|
||||
split_param_row_tp1d(x_colo, pg)
|
||||
|
||||
z = x.view(torch.Size((-1, 2, 2)))
|
||||
z_colo = x_colo.view(torch.Size((-1, 2, 2)))
|
||||
if dist.get_rank() == 0:
|
||||
z = z[0:2, :, :]
|
||||
else:
|
||||
z = z[2:, :, :]
|
||||
assert torch.all(z == z_colo)
|
||||
assert z_colo.dist_spec == x_colo.dist_spec
|
||||
# the normal case of row-sliced ColoTensors
|
||||
z = x.view(-1, 2, 2, 2)
|
||||
z_colo = x_colo.view(-1, 2, 2, 2)
|
||||
assert torch.all(z == z_colo)
|
||||
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
|
||||
|
||||
|
||||
def exam_view_autograd(pg):
|
||||
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
|
||||
y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
|
||||
with torch.no_grad():
|
||||
y.copy_(x)
|
||||
y = ColoTensor(y, ColoTensorSpec(pg))
|
||||
y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
|
||||
|
||||
xx = x.view(2, 2, -1)
|
||||
yy_slice = y_slice.view(2, 2, -1)
|
||||
yy = yy_slice.to_replicate()
|
||||
grad = torch.randn(2, 2, 4, device=get_current_device())
|
||||
|
||||
xx.backward(grad)
|
||||
yy.backward(grad)
|
||||
assert torch.all(x.grad == y.grad)
|
||||
|
||||
|
||||
def exam_view_errors(pg):
|
||||
x = torch.randn(8, 2, device=get_current_device())
|
||||
x = ColoTensor(x, ColoTensorSpec(pg))
|
||||
split_param_row_tp1d(x, pg)
|
||||
|
||||
x.view('a', 'b', 'c')
|
||||
x.view(8, -1)
|
||||
x.view([-2, -2, -2])
|
||||
x.view((-1, -1, -1))
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
|
||||
exam_view_core(pg)
|
||||
exam_view_autograd(pg)
|
||||
# exam_view_errors(pg)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_view(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_view(2)
|
||||
|
Reference in New Issue
Block a user