mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[test] fixed tests that failed due to the dtensor change (#4082)
* [test] fixed tests that failed due to the dtensor change * polish code
This commit is contained in:
@@ -195,7 +195,7 @@ class ShardingSpec:
|
||||
def __repr__(self):
    """Return a human-readable summary of this sharding spec.

    The summary lists the per-dimension shard sequence (one token per
    logical tensor dimension, e.g. ``S0,S1,R``) followed by the shape of
    the device mesh the spec is laid out on.

    Returns:
        str: a single-line description such as
        ``DistSpec: \\n\\tshard_sequence: S0,R \\n\\tdevice_mesh_shape: (2, 2)``.
    """
    res_list = ["DistSpec:"]
    # Join the DimSpec objects' string forms into a comma-separated sequence.
    res_list.append("\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
    # NOTE(review): the dtensor refactor renamed DeviceMesh.mesh_shape to
    # DeviceMesh.shape; only the new attribute is used here.
    res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}")
    return ' '.join(res_list)
|
||||
|
||||
def _sanity_check(self):
|
||||
@@ -222,7 +222,7 @@ class ShardingSpec:
|
||||
num_devices = 1
|
||||
|
||||
for element in shard_list:
|
||||
num_devices *= self.device_mesh.mesh_shape[element]
|
||||
num_devices *= self.device_mesh.shape[element]
|
||||
|
||||
if tensor_dim_size % num_devices != 0:
|
||||
raise ShardingNotDivisibleError(
|
||||
@@ -288,7 +288,7 @@ class ShardingSpec:
|
||||
|
||||
sharded_shape = list(self.entire_shape)
|
||||
for dim, shard_list in self.dim_partition_dict.items():
|
||||
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
|
||||
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
|
||||
shard_partitions = reduce(operator.mul, mesh_list, 1)
|
||||
assert sharded_shape[
|
||||
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
|
||||
|
Reference in New Issue
Block a user