[shardformer] integrated linear 1D with dtensor (#3996)

* [shardformer] integrated linear 1D with dtensor

* polish code
This commit is contained in:
Frank Lee
2023-06-15 18:03:38 +08:00
parent d3bc530849
commit 015af592f8
9 changed files with 707 additions and 408 deletions

View File

@@ -34,7 +34,7 @@ class Layout:
def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert sharded_shape[
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
@@ -45,14 +45,15 @@ class Layout:
sharding_spec = self.sharding_spec
# make sure all axes in logical device mesh only be used once
dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
for dim, shard_list in sharding_spec.dim_partition_dict.items():
for element in shard_list:
if element in dim_check_list:
dim_check_list.remove(element)
else:
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
if self.device_mesh.logical_mesh_id is not None:
dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
for dim, shard_list in sharding_spec.dim_partition_dict.items():
for element in shard_list:
if element in dim_check_list:
dim_check_list.remove(element)
else:
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
@@ -60,7 +61,7 @@ class Layout:
num_devices = 1
for element in shard_list:
num_devices *= self.device_mesh.mesh_shape[element]
num_devices *= self.device_mesh.shape[element]
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(