[DTensor] refactor CommSpec (#3034)

This commit is contained in:
YuliangLiu0306
2023-03-08 10:45:31 +08:00
committed by GitHub
parent ea0b52c12e
commit 29386a54e6
3 changed files with 501 additions and 1 deletions

View File

@@ -171,7 +171,7 @@ class ShardingSpec:
raise ShardingOutOfIndexError(
f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.')
if max(list(self.dim_partition_dict.keys())) >= self.dims:
if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims:
raise ShardingOutOfIndexError(
f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.'
)