[tensor] add shape consistency feature to support auto spec transform (#1418)

* [tensor] add shape consistency feature to support auto sharding spec transform.

* [tensor] remove unused argument in simulator, add doc string for target pair.
Author: YuliangLiu0306
Date: 2022-08-10 11:29:17 +08:00
Committed by: GitHub
Parent: 4fb3c52cf0
Commit: 33f0744d51
3 changed files with 424 additions and 5 deletions


@@ -13,7 +13,7 @@ class _DimSpec:
     '''
     def __init__(self, shard_list):
-        self.is_replica = shard_list is None
+        self.is_replica = len(shard_list) == 0
         self.shard_list = shard_list

     def __eq__(self, other):
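The hunk above changes the replica convention in _DimSpec: a tensor dimension now counts as replicated when its shard list is empty rather than None, so [] becomes the canonical "no sharding" value. A minimal stand-in illustrating the new convention (not the library's full _DimSpec):

class _DimSpec:
    # Simplified stand-in: a tensor dimension is replicated when no
    # device-mesh axis shards it, i.e. when its shard list is empty.
    def __init__(self, shard_list):
        self.is_replica = len(shard_list) == 0
        self.shard_list = shard_list

replicated = _DimSpec([])     # replicated across every device along this dimension
sharded = _DimSpec([0, 1])    # sharded over device-mesh axes 0 and 1
assert replicated.is_replica and not sharded.is_replica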
@@ -52,12 +52,16 @@ class ShardingSpec:
     and the value of the key describes which logical axes will be sharded in that dimension.
     '''
-    def __init__(self, device_mesh, entire_shape, dim_partition_dict):
+    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
         self.device_mesh = device_mesh
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
+        self.sharding_sequence = sharding_sequence
+        if self.sharding_sequence is None:
+            self.convert_dict_to_shard_sequence()
+        elif self.dim_partition_dict is None:
+            self.convert_shard_sequence_to_dict()
         self._sanity_check()
-        self.sharding_sequence = self.convert_dict_to_shard_sequence()

     def __repr__(self):
         res_list = ["DistSpec:"]
@@ -80,10 +84,19 @@ class ShardingSpec:
                     f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

     def convert_dict_to_shard_sequence(self):
-        sharding_sequence = [_DimSpec(None)] * len(self.entire_shape)
+        sharding_sequence = [_DimSpec([])] * len(self.entire_shape)
         for dim, shard_list in self.dim_partition_dict.items():
             sharding_sequence[dim] = _DimSpec(shard_list)
-        return sharding_sequence
+        self.sharding_sequence = sharding_sequence
+
+    def convert_shard_sequence_to_dict(self):
+        new_dim_partition_dict = {}
+        for index, dim_spec in enumerate(self.sharding_sequence):
+            if not dim_spec.is_replica:
+                if index not in new_dim_partition_dict:
+                    new_dim_partition_dict[index] = []
+                new_dim_partition_dict[index].extend(dim_spec.shard_list)
+        self.dim_partition_dict = new_dim_partition_dict

     def sharding_sequence_difference(self, other):
         '''
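For well-formed specs, the two conversions above are inverses of each other. A self-contained sketch of the round trip, with the method bodies lifted into free functions for illustration (the function names here are hypothetical, not part of the library):

class _DimSpec:
    def __init__(self, shard_list):
        self.is_replica = len(shard_list) == 0
        self.shard_list = shard_list

def dict_to_sequence(dim_partition_dict, num_dims):
    # Start fully replicated, then overwrite the sharded dimensions.
    sequence = [_DimSpec([]) for _ in range(num_dims)]
    for dim, shard_list in dim_partition_dict.items():
        sequence[dim] = _DimSpec(shard_list)
    return sequence

def sequence_to_dict(sharding_sequence):
    # Collect the device-mesh axes of every non-replicated dimension,
    # keeping each dict value a flat list of mesh axes.
    new_dim_partition_dict = {}
    for index, dim_spec in enumerate(sharding_sequence):
        if not dim_spec.is_replica:
            new_dim_partition_dict.setdefault(index, []).extend(dim_spec.shard_list)
    return new_dim_partition_dict

# Round trip: dict -> sequence -> dict recovers the original mapping.
original = {0: [0], 2: [1]}
assert sequence_to_dict(dict_to_sequence(original, num_dims=3)) == original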