Mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] fix bugs caused by negative dim key (#1808)
* [autoparallel] fix bugs caused by negative dim key
* fix import error
* fix matmul test issue
* fix unit test issue
@@ -6,6 +6,8 @@ import torch
 
 from colossalai.device.device_mesh import DeviceMesh
 
+from .utils import merge_same_dim_mesh_list
+
 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
 
 ALLGATHER_COST = 20
@@ -181,8 +183,12 @@ class ShardingSpec:
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
         if self.sharding_sequence is None:
             assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
+            self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
+                                                               dim_partition_dict=self.dim_partition_dict)
             self.convert_dict_to_shard_sequence()
         elif self.dim_partition_dict is None:
             assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
             self.convert_shard_sequence_to_dict()
         self._sanity_check()
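The helper is imported from the sibling utils module (from .utils import merge_same_dim_mesh_list) and applied to dim_partition_dict before the sharding sequence is built. Going by the commit title and the call site, the bug it fixes is that a negative dim key and its positive form name the same tensor dim yet were kept as separate entries: on a 2D tensor, {-1: [0]} and {1: [1]} both describe dim 1. The sketch below is a minimal illustration of that normalization under those assumptions, not the library's actual implementation.

# A minimal sketch of what merge_same_dim_mesh_list plausibly does here:
# normalize negative dim keys to their positive form and merge the
# device-mesh axis lists of entries that name the same tensor dim.
# Illustration only; the real helper lives in the library's utils module.
from typing import Dict, List


def merge_same_dim_mesh_list(dim_size: int,
                             dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
    merged: Dict[int, List[int]] = {}
    for dim, mesh_list in dim_partition_dict.items():
        # A negative key such as -1 refers to dim_size - 1; normalizing it
        # prevents {-1: ...} and {dim_size - 1: ...} from being treated as
        # two different shard dims, which is the bug this commit targets.
        if dim < 0:
            dim += dim_size
        merged.setdefault(dim, []).extend(mesh_list)
    return merged


# Example: on a 2D tensor, -1 and 1 are the same dim, so their
# device-mesh axes collapse into a single entry.
assert merge_same_dim_mesh_list(2, {-1: [0], 1: [1]}) == {1: [0, 1]}

Normalizing inside ShardingSpec's constructor, rather than at each call site, plausibly means every downstream consumer of dim_partition_dict sees canonical non-negative keys, which would explain why the fix lands here instead of in the individual autoparallel handlers.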