[autoparallel] fixed wrong generated strategy for dot op (#1746)

* [autoparallel] fixed wrong generated strategy for dot op

* polish code
This commit is contained in:
Frank Lee
2022-10-20 15:18:16 +08:00
committed by GitHub
parent 993b8875b6
commit 8b8937d901
13 changed files with 187 additions and 116 deletions

View File

@@ -8,12 +8,12 @@ import torch
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
]
def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
"""
Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
@@ -22,19 +22,26 @@ def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> S
dim1 (int): the tensor dimension to switch
dim2 (int): the tensor dimension to switch
"""
assert len(sharding_spec.entire_shape) == 2
assert len(sharding_spec.entire_shape) >= 2, \
'The entire_shape of the sharding spec must have at least 2 dimensions'
dim_partition_dict = sharding_spec.dim_partition_dict
# transpose the dim partition
dim1_partition = dim_partition_dict.pop(dim1, None)
dim2_partition = dim_partition_dict.pop(dim2, None)
if dim1_partition:
dim_partition_dict[dim2] = dim1_partition
if dim2_partition:
dim_partition_dict[dim1] = dim2_partition
# get the transposed shape
new_shape = list(sharding_spec.entire_shape[:])
new_shape[dim2], new_shape[dim1] = new_shape[dim1], new_shape[dim2]
new_shape = torch.Size(new_shape)
# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
sharding_spec.__init__(sharding_spec.device_mesh, new_shape, dim_partition_dict)
return sharding_spec