[autoparallel] shard param and buffer as expected (#1753)

* [autoparallel] shard param and buffer as expected

* fix unit test issue
Author: YuliangLiu0306
Date: 2022-10-21 15:45:13 +08:00 (committed by GitHub)
parent cdb7d5e7d2
commit 980ed21723
6 changed files with 129 additions and 106 deletions
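
Every hunk in the test file below applies the same one-line pattern: the call sites now consume the return value of `covert_spec_to_action` instead of assuming the argument tensor is modified in place, so each test rebinds its tensor to the result. A minimal sketch of the new call pattern; the `apply_spec` wrapper is hypothetical, only the `covert_spec_to_action` call itself comes from this diff:

```python
import torch

from colossalai.tensor.shape_consistency import CommSpec


def apply_spec(comm_spec: CommSpec, tensor: torch.Tensor) -> torch.Tensor:
    # Old call sites discarded the return value and relied on in-place mutation:
    #     comm_spec.covert_spec_to_action(tensor)
    # After this change the transformed tensor is returned, so callers rebind it:
    tensor = comm_spec.covert_spec_to_action(tensor)
    return tensor
```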


@@ -1,17 +1,19 @@
import torch
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
def check_all_gather(device_mesh, rank):
@@ -37,7 +39,7 @@ def check_all_gather(device_mesh, rank):
sharding_spec,
gather_dim=1,
logical_process_axis=1)
-    comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
+    sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
assert sharded_tensor_to_comm.equal(tensor_to_check)
@@ -60,7 +62,7 @@ def check_shard(device_mesh, rank):
# CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
-    comm_spec.covert_spec_to_action(tensor_to_shard)
+    tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)
if rank in (0, 2):
assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
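
For context on what `check_shard` verifies: splitting along `shard_dim=1` over `logical_process_axis=1` leaves ranks 0 and 2 holding the first chunk and ranks 1 and 3 the second, as the assert above shows. A small sketch of the expected per-rank tensors, assuming a 2x2 mesh laid out as [[0, 1], [2, 3]]; the concrete values are illustrative, not the test's actual fixtures:

```python
import torch

# Assumed replicated tensor on every rank before sharding: left half ones, right half zeros.
tensor_to_shard = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), dim=1)

# After CommSpec(SPLIT_FWD_GATHER_BWD, shard_dim=1, logical_process_axis=1) on a 2x2 mesh,
# ranks 0 and 2 should keep the left half and ranks 1 and 3 the right half.
sharded_tensor_to_comm_0 = torch.ones(2, 2)     # expected on ranks 0 and 2
sharded_tensor_to_comm_1 = torch.zeros(2, 2)    # expected on ranks 1 and 3
```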
@@ -110,7 +112,7 @@ def check_all_to_all(device_mesh, rank):
gather_dim=0,
shard_dim=1,
logical_process_axis=0)
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
@@ -137,7 +139,7 @@ def check_all_reduce_fwd(device_mesh, rank):
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
@@ -155,7 +157,7 @@ def check_all_reduce_bwd(device_mesh, rank):
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
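
`check_all_reduce_fwd` and `check_all_reduce_bwd` exercise the two directions encoded in the pattern names: `ALLREDUCE_FWD_IDENTITY_BWD` all-reduces when the spec is applied and leaves gradients untouched, while `IDENTITY_FWD_ALLREDUCE_BWD` defers the all-reduce to the backward pass. A sketch of the two constructions side by side; the helper is hypothetical and the autograd comments follow from the pattern names rather than from this diff:

```python
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec


def build_allreduce_specs(sharding_spec):
    # Forward all-reduce: summed across logical process axis 0 when the spec is applied;
    # gradients pass through unchanged in backward.
    fwd_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
    # Identity in forward; the all-reduce is applied to the gradient during backward.
    bwd_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
    return fwd_spec, bwd_spec
```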
@@ -178,7 +180,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
# CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
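
For completeness, a hedged sketch of how these `check_*` functions are typically driven, reconstructed from the imports at the top of the file (`launch`, `free_port`, `mp`, `DeviceMesh`, `disable_existing_loggers`, `rerun_if_address_is_in_use`). The 2x2 mesh shape, port handling, and exact `launch`/`DeviceMesh` arguments are assumptions, not lines from this diff:

```python
from functools import partial

import torch
import torch.multiprocessing as mp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def check_comm(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # Assumed 2x2 logical mesh over 4 physical devices: [[0, 1], [2, 3]].
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    check_all_gather(device_mesh, rank)
    check_shard(device_mesh, rank)
    check_all_to_all(device_mesh, rank)
    check_all_reduce_fwd(device_mesh, rank)
    check_all_reduce_bwd(device_mesh, rank)
    check_all_reduce_in_flatten_device_mesh(device_mesh, rank)


@rerun_if_address_is_in_use()
def test_comm_spec():
    world_size = 4
    run_func = partial(check_comm, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
```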