mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests that failed due to the dtensor change
* polish code
This commit is contained in:
@@ -285,7 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD

         # legal sharding dims means the mesh_id is still available to use.
-        legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
+        legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))]
         for dim, shard_list in source_spec.dim_partition_dict.items():
             for element in shard_list:
                 legal_sharding_dims.remove(element)
@@ -435,7 +435,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         """
         input_shape = compute_shape(comm_spec.sharding_spec)
         input_numel = np.prod(input_shape)
-        output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+        output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
         peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
         alloc_numel += output_numel
         if discard_input:
@@ -461,7 +461,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # generate a new tensor
             input_shape = compute_shape(comm_spec.sharding_spec)
             input_numel = np.prod(input_shape)
-            output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+            output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
             alloc_numel += output_numel
             peak_numel = max(peak_numel, alloc_numel)
             if discard_input:
|
Reference in New Issue
Block a user