mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 12:47:21 +00:00)
[autoparallel] shard param and buffer as expected (#1753)
* [autoparallel] shard param and buffer as expected
* fix unit test issue
@@ -1,17 +1,19 @@
-import torch
 from functools import partial
 
 import pytest
+import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.distributed import ReduceOp
 
 from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
 
 
 def check_all_gather(device_mesh, rank):
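For context on the hunks below: every check presumably runs on four ranks arranged as a 2 x 2 DeviceMesh, so each logical_process_axis indexes a group of two ranks. A minimal sketch of that setup, assuming the DeviceMesh constructor arguments of this ColossalAI version (the mesh shape and the init_process_group flag are assumptions, not shown in this hunk):

import torch

from colossalai.device.device_mesh import DeviceMesh

# Assumed setup: ranks 0..3 laid out as a 2 x 2 logical mesh [[0, 1], [2, 3]].
# Along axis 0 the rank groups are {0, 2} and {1, 3}; along axis 1 they are
# {0, 1} and {2, 3}.
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)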
@@ -37,7 +39,7 @@ def check_all_gather(device_mesh, rank):
                          sharding_spec,
                          gather_dim=1,
                          logical_process_axis=1)
-    comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
+    sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
 
     assert sharded_tensor_to_comm.equal(tensor_to_check)
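What the GATHER_FWD_SPLIT_BWD check above verifies can be pictured without any distributed setup: an all-gather with gather_dim=1 is the inverse of splitting a tensor along dim 1. A single-process sketch of the semantics (plain PyTorch, not ColossalAI internals):

import torch

full = torch.arange(16).reshape(4, 4)
shards = torch.chunk(full, chunks=2, dim=1)   # what each rank holds before the all-gather
regathered = torch.cat(shards, dim=1)         # what every rank holds afterwards
assert regathered.equal(full)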
@@ -60,7 +62,7 @@ def check_shard(device_mesh, rank):
 
     # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
     comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
-    comm_spec.covert_spec_to_action(tensor_to_shard)
+    tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)
 
     if rank in (0, 2):
         assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
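The SPLIT_FWD_GATHER_BWD check is the mirror image: each rank keeps the chunk of shard_dim=1 indexed by its coordinate along the logical process axis. Assuming the 2 x 2 mesh sketched earlier, ranks 0 and 2 both sit at coordinate 0 of axis 1, which is why the assertion above branches on rank in (0, 2). A single-process sketch of the semantics:

import torch

full = torch.arange(16).reshape(4, 4)
chunks = torch.chunk(full, chunks=2, dim=1)
shard_for_coord_0 = chunks[0]   # kept by ranks at coordinate 0 of the process axis
shard_for_coord_1 = chunks[1]   # kept by ranks at coordinate 1 of the process axis
assert torch.cat([shard_for_coord_0, shard_for_coord_1], dim=1).equal(full)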
@@ -110,7 +112,7 @@ def check_all_to_all(device_mesh, rank):
                          gather_dim=0,
                          shard_dim=1,
                          logical_process_axis=0)
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
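The all-to-all check combines both moves: gather along dim 0 while re-sharding along dim 1, i.e. a reshard from row shards to column shards. A single-process sketch of the data movement (plain PyTorch):

import torch

# Two ranks, each holding a (2, 4) piece sharded along dim 0.
piece0 = torch.arange(0, 8).reshape(2, 4)
piece1 = torch.arange(8, 16).reshape(2, 4)

full = torch.cat([piece0, piece1], dim=0)          # gather_dim=0
resharded = torch.chunk(full, chunks=2, dim=1)     # shard_dim=1
# After the all-to-all, rank i holds resharded[i], a (4, 2) piece.
assert all(piece.shape == (4, 2) for piece in resharded)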
@@ -137,7 +139,7 @@ def check_all_reduce_fwd(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
 
     comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
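CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD names a forward/backward pair: all-reduce the tensor on the forward pass, pass gradients through unchanged on the backward pass. A minimal autograd sketch of that pattern, assuming an already-initialized default process group (an illustration, not ColossalAI's actual CommSpec implementation):

import torch
import torch.distributed as dist


class AllReduceFwdIdentityBwd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, group=None):
        # Sum-reduce across the group in the forward pass.
        output = tensor.clone()
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Identity in the backward pass: gradients flow through unchanged.
        return grad_output, None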
@@ -155,7 +157,7 @@ def check_all_reduce_bwd(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
 
     comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
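IDENTITY_FWD_ALLREDUCE_BWD is the transpose of the pattern above: the forward pass is an identity and the backward pass all-reduces the gradient. Again only a sketch under the same assumptions:

import torch
import torch.distributed as dist


class IdentityFwdAllReduceBwd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, group=None):
        # Identity in the forward pass; remember the group for backward.
        ctx.group = group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Sum-reduce gradients across the group in the backward pass.
        grad_output = grad_output.clone()
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_output, None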
@@ -178,7 +180,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
 
     # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
     comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
 
     assert tensor_to_comm.equal(tensor_to_check)
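Taken together, the hunks make one consistent change: covert_spec_to_action (the spelling is the library's own) now returns the transformed tensor instead of being relied on for in-place mutation, so every call site must capture the result:

# Old calling convention (before this commit): in-place mutation.
# comm_spec.covert_spec_to_action(tensor_to_comm)

# New calling convention (this commit): capture the returned tensor.
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)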