[hotfix] fix unit test test_module_spec (#1321)

Authored by HELSON on 2022-07-15 14:02:32 +08:00, committed by GitHub
parent 9e4c6449b0
commit 1b41686461
3 changed files with 29 additions and 22 deletions


@@ -1,11 +1,11 @@
-from copy import copy
+from copy import deepcopy
 import pytest
 from functools import partial
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
+from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
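
The import switch above matters because copy.copy of an nn.Module is shallow: the copy shares the original's Parameter objects, so the two models in the test would also share the same .grad buffers. A minimal sketch of the difference, separate from the commit itself:

from copy import copy, deepcopy

import torch

model = torch.nn.Linear(4, 8)
shallow = copy(model)            # shallow copy: shares model's parameters
independent = deepcopy(model)    # deep copy: owns its own parameters

assert shallow.weight is model.weight            # same Parameter object
assert independent.weight is not model.weight    # separate Parameter object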
@@ -112,21 +112,25 @@ def run_linear_with_spec(mode):
     with ColoInitContext(device=get_current_device()):
         model = torch.nn.Linear(4, 8)
-    model_handy = copy(model)
+    model_handy = deepcopy(model)
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     compute_spec = ComputeSpec(ComputePattern.TP1D)
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
     x = torch.rand(2, 4).cuda()
+    colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))
     out = model(x)
-    colo_out = model_handy(x)
+    colo_out = model_handy(colo_x)
     assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
-    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
+    assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
+    assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size())


 def run_check_shared_param():
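
In the new version the test input is wrapped as a ColoTensor carrying the ProcessGroup before the forward pass, and tensor_shard_equal now receives the handy model's gradient first and the reference gradient second. A hypothetical sketch, not the _utils implementation and assuming 1D sharding along dim 0, of what a shard-aware comparison in that argument order could look like:

import torch

def shard_equal_sketch(maybe_shard: torch.Tensor, whole: torch.Tensor, rank: int, world_size: int) -> bool:
    # If the tensor was never sharded, compare it against the full reference directly.
    if maybe_shard.shape == whole.shape:
        return torch.allclose(maybe_shard, whole)
    # Otherwise compare it against this rank's chunk of the full reference.
    return torch.allclose(maybe_shard, torch.chunk(whole, world_size, dim=0)[rank])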
@@ -196,7 +200,7 @@ def run_dist_check(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_linear_1d(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
@@ -205,7 +209,7 @@ def test_module_linear_1d(world_size):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_model(world_size):
     run_func = partial(run_dist_model, world_size=world_size, port=free_port())
@@ -214,7 +218,7 @@ def test_module_model(world_size):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_check(world_size):
     run_func = partial(run_dist_check, world_size=world_size, port=free_port())
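
Each of these tests binds its per-rank body with functools.partial and, by the usual ColossalAI test convention not shown in these hunks, hands it to torch.multiprocessing.spawn. A sketch of that launch pattern, assuming the per-rank functions keep the (rank, world_size, port) signature from the hunk headers; run_dist_stub here is a placeholder, not the real test body:

from functools import partial

import torch.multiprocessing as mp

def run_dist_stub(rank, world_size, port):
    # placeholder standing in for the real per-rank test logic
    print(f"rank {rank}/{world_size} would init distributed on port {port}")

def launch(world_size: int, port: int = 29500):
    # mp.spawn calls run_func(rank) on every rank; partial fills in the rest
    run_func = partial(run_dist_stub, world_size=world_size, port=port)
    mp.spawn(run_func, nprocs=world_size)

if __name__ == '__main__':
    launch(2)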
@@ -222,4 +226,4 @@ def test_module_check(world_size):
 if __name__ == '__main__':
-    test_module_check(2)
+    test_module_linear_1d(4)