[autoparallel] fix bugs caused by negative dim key (#1808)

* [autoparallel] fix bugs caused by negative dim key

* fix import error

* fix matmul test issue

* fix unit test issue
YuliangLiu0306 committed 2022-11-08 17:03:50 +08:00 (committed by GitHub)
parent 4268ae017b
commit 49216d7ab1
12 changed files with 108 additions and 43 deletions


@@ -6,6 +6,8 @@ import torch
 from colossalai.device.device_mesh import DeviceMesh
+from .utils import merge_same_dim_mesh_list
 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
 ALLGATHER_COST = 20
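
The helper itself is not part of this hunk. As an illustration of the bug class: a `dim_partition_dict` such as `{-1: [1], 1: [0]}` uses two keys that alias the same dimension of a 2-D tensor, so lookups on key `1` silently miss the `-1` entry. Below is a minimal sketch of the merging behaviour such a helper needs, assumed from its name and call site rather than taken from ColossalAI's actual implementation:

```python
def merge_same_dim_mesh_list(dim_size, dim_partition_dict):
    """Normalize negative dim keys and merge entries that alias the same dim.

    Sketch only: the real helper in colossalai may differ in details.
    """
    merged = {}
    for dim, mesh_axes in dim_partition_dict.items():
        # A negative key counts from the end, like ordinary tensor indexing.
        if dim < 0:
            dim += dim_size
        merged.setdefault(dim, []).extend(mesh_axes)
    return merged


# merge_same_dim_mesh_list(2, {-1: [1], 1: [0]}) -> {1: [1, 0]}
```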
@@ -181,8 +183,12 @@ class ShardingSpec:
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
         if self.sharding_sequence is None:
             assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
+            self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
+                                                               dim_partition_dict=self.dim_partition_dict)
             self.convert_dict_to_shard_sequence()
         elif self.dim_partition_dict is None:
             assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
             self.convert_shard_sequence_to_dict()
         self._sanity_check()
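
With the merge applied before `convert_dict_to_shard_sequence`, a spec written with a negative dim key should resolve to the same sharding sequence as its positive-key equivalent. A rough usage sketch follows; the `DeviceMesh` arguments and module paths are assumptions for illustration and may not match the repository exactly:

```python
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# Assumed setup: a 2x2 logical mesh over 4 devices, no process groups created.
device_mesh = DeviceMesh(torch.arange(4), (2, 2), init_process_group=False)
entire_shape = torch.Size((64, 32))

# After the fix, key -1 is merged onto key 1, so both specs shard the last
# tensor dimension across mesh axis 1 and should describe the same layout.
spec_neg = ShardingSpec(device_mesh, entire_shape, dim_partition_dict={-1: [1]})
spec_pos = ShardingSpec(device_mesh, entire_shape, dim_partition_dict={1: [1]})
assert str(spec_neg.sharding_sequence) == str(spec_pos.sharding_sequence)
```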