mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[tensor] add shape consistency feature to support auto spec transform (#1418)
* [tensor] add shape consistency feature to support auto sharding spec transform. * [tensor] remove unused argument in simulator, add doc string for target pair.
This commit is contained in:
@@ -13,7 +13,7 @@ class _DimSpec:
|
||||
'''
|
||||
|
||||
def __init__(self, shard_list):
|
||||
self.is_replica = shard_list is None
|
||||
self.is_replica = len(shard_list) == 0
|
||||
self.shard_list = shard_list
|
||||
|
||||
def __eq__(self, other):
|
||||
@@ -52,12 +52,16 @@ class ShardingSpec:
|
||||
and the value of the key describe which logical axis will be sharded in that dimension.
|
||||
'''
|
||||
|
||||
def __init__(self, device_mesh, entire_shape, dim_partition_dict):
|
||||
def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
|
||||
self.device_mesh = device_mesh
|
||||
self.entire_shape = entire_shape
|
||||
self.dim_partition_dict = dim_partition_dict
|
||||
self.sharding_sequence = sharding_sequence
|
||||
if self.sharding_sequence is None:
|
||||
self.convert_dict_to_shard_sequence()
|
||||
elif self.dim_partition_dict is None:
|
||||
self.convert_shard_sequence_to_dict()
|
||||
self._sanity_check()
|
||||
self.sharding_sequence = self.convert_dict_to_shard_sequence()
|
||||
|
||||
def __repr__(self):
|
||||
res_list = ["DistSpec:"]
|
||||
@@ -80,10 +84,19 @@ class ShardingSpec:
|
||||
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
|
||||
|
||||
def convert_dict_to_shard_sequence(self):
|
||||
sharding_sequence = [_DimSpec(None)] * len(self.entire_shape)
|
||||
sharding_sequence = [_DimSpec([])] * len(self.entire_shape)
|
||||
for dim, shard_list in self.dim_partition_dict.items():
|
||||
sharding_sequence[dim] = _DimSpec(shard_list)
|
||||
return sharding_sequence
|
||||
self.sharding_sequence = sharding_sequence
|
||||
|
||||
def convert_shard_sequence_to_dict(self):
    '''
    Rebuild self.dim_partition_dict from self.sharding_sequence.

    Every entry of self.sharding_sequence whose is_replica flag is false
    contributes one item to the result: the key is the entry's position in
    the sequence and the value is a flat copy of the entry's shard_list.
    Entries whose is_replica flag is true are left out.

    The result is stored on self.dim_partition_dict; nothing is returned.
    '''
    # Store the shard list itself (copied), not a list wrapping it: the
    # values of dim_partition_dict are passed straight to _DimSpec by
    # convert_dict_to_shard_sequence, so a nested list would not round-trip.
    # Copying keeps the dict from aliasing the list held by the dim spec.
    self.dim_partition_dict = {
        index: list(dim_spec.shard_list)
        for index, dim_spec in enumerate(self.sharding_sequence)
        if not dim_spec.is_replica
    }
|
||||
|
||||
def sharding_sequence_difference(self, other):
|
||||
'''
|
||||
|
Reference in New Issue
Block a user