[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change

* polish code
This commit is contained in:
Frank Lee
2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@@ -285,7 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))]
for dim, shard_list in source_spec.dim_partition_dict.items():
for element in shard_list:
legal_sharding_dims.remove(element)
@@ -435,7 +435,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
"""
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
alloc_numel += output_numel
if discard_input:
@@ -461,7 +461,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# generate a new tensor
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
alloc_numel += output_numel
peak_numel = max(peak_numel, alloc_numel)
if discard_input: