[fix] fixed the collective pattern name for consistency (#1649)

* [fix] fixed the collective pattern name for consistency

* polish code
This commit is contained in:
Frank Lee
2022-09-26 16:39:37 +08:00
committed by GitHub
parent b2b2a4af98
commit 154d3ef432
2 changed files with 7 additions and 7 deletions

View File

@@ -136,7 +136,7 @@ def check_all_reduce_fwd(device_mesh, rank):
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
comm_spec = CommSpec(CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
@@ -177,7 +177,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
# CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
comm_spec = CommSpec(CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)