[autoparallel] shard param and buffer as expected (#1753)

* [autoparallel] shard param and buffer as expected

* fix unit test issue
YuliangLiu0306
2022-10-21 15:45:13 +08:00
committed by GitHub
parent cdb7d5e7d2
commit 980ed21723
6 changed files with 129 additions and 106 deletions


@@ -1,15 +1,16 @@
 from functools import partial
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
 def check_apply(rank, world_size, port):
@@ -63,7 +64,7 @@ def check_apply(rank, world_size, port):
     tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()
     tensor_to_comm.sharding_spec = sharding_spec_source
-    shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
+    tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
     assert tensor_to_comm.equal(tensor_to_check)
     assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)
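The unit-test fix changes how the converted tensor is obtained: instead of relying on the call to mutate `tensor_to_comm` in place, the test now keeps the tensor that `shape_consistency_manager.apply` returns. A minimal CPU-only sketch of that calling convention, using plain PyTorch rather than the ColossalAI API (`apply_layout` is a hypothetical stand-in for `apply`):

import torch

def apply_layout(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for ShapeConsistencyManager.apply: the conversion
    # returns a new tensor (here simply a copy) instead of updating its
    # argument in place, so the caller must capture the return value.
    return tensor.clone()

tensor_to_comm = torch.tensor([[1], [1], [3], [3]])
tensor_to_check = torch.tensor([[1], [1], [3], [3]])

# Old pattern -- the converted tensor is silently discarded:
#   apply_layout(tensor_to_comm)
# Fixed pattern, mirroring the updated test:
tensor_to_comm = apply_layout(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)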