[autoparallel] fixed wrong generated strategy for dot op (#1746)

* [autoparallel] fixed wrong generated strategy for dot op

* polish code
Frank Lee
2022-10-20 15:18:16 +08:00
committed by GitHub
parent 993b8875b6
commit 8b8937d901
13 changed files with 187 additions and 116 deletions

View File

@@ -5,13 +5,13 @@ from .sharding import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size,
switch_partition_dim,
tranpose_partition_dim,
update_partition_dim,
)
__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity',
'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
]

View File

@@ -36,9 +36,10 @@ def ignore_sharding_exception(func):
def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
"""
This function checks whether the ShardingSpec is valid for the physical tensor.
This check includes 2 items:
This check includes 3 items:
1. the sharding spec covers all dimensions of the physical tensor
2. the size of each sharded dimension is divisible by the number of devices it is sharded over.
3. the sharding spec's entire shape must match the tensor shape
"""
# make sure all dims are covered in sharding spec
@@ -65,3 +66,6 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
assert dim_size >= num_devices and dim_size % num_devices == 0, \
f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
# make sure the entire shape matches the physical tensor shape
assert sharding_spec.entire_shape == tensor.shape
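
For reference, the three checks described above reduce to the following standalone sketch (illustrative only: it mirrors the logic of check_sharding_spec_validity with a plain entire_shape / dim_partition_dict pair in place of a real ShardingSpec, and the function name is made up for this example):

import torch

def check_sharding_spec_validity_sketch(entire_shape, dim_partition_dict, mesh_shape, tensor):
    # 1. the spec must describe every dimension of the physical tensor
    assert len(entire_shape) == tensor.dim()
    # 2. each sharded dimension must be divisible by the number of devices it is sharded over
    for dim, mesh_dims in dim_partition_dict.items():
        num_devices = 1
        for mesh_dim in mesh_dims:
            num_devices *= mesh_shape[mesh_dim]
        dim_size = entire_shape[dim]
        assert dim_size >= num_devices and dim_size % num_devices == 0, \
            f'The dimension at index {dim} has value {dim_size}, but it is sharded over {num_devices} devices.'
    # 3. the spec's entire shape must match the tensor shape (the check added in this commit)
    assert entire_shape == tensor.shape

# a (4, 8) tensor whose dim 1 is sharded over mesh dim 0 of a (2, 2) mesh passes all three checks
check_sharding_spec_validity_sketch(torch.Size([4, 8]), {1: [0]}, (2, 2), torch.rand(4, 8))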

View File

@@ -8,12 +8,12 @@ import torch
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
]
def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
"""
Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
@@ -22,19 +22,26 @@ def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> S
dim1 (int): the tensor dimension to switch
dim2 (int): the tensor dimension to switch
"""
assert len(sharding_spec.entire_shape) == 2
assert len(sharding_spec.entire_shape) >= 2, \
'The entire_shape of the sharding spec must have at least 2 dimensions'
dim_partition_dict = sharding_spec.dim_partition_dict
# transpose the dim partition
dim1_partition = dim_partition_dict.pop(dim1, None)
dim2_partition = dim_partition_dict.pop(dim2, None)
if dim1_partition:
dim_partition_dict[dim2] = dim1_partition
if dim2_partition:
dim_partition_dict[dim1] = dim2_partition
# get the transposed shape
new_shape = list(sharding_spec.entire_shape[:])
new_shape[dim2], new_shape[dim1] = new_shape[dim1], new_shape[dim2]
new_shape = torch.Size(new_shape)
# re-init the sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
sharding_spec.__init__(sharding_spec.device_mesh, new_shape, dim_partition_dict)
return sharding_spec
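
A hedged usage sketch of the renamed helper, following the logic shown in this hunk (the import path for tranpose_partition_dim and the concrete mesh/shape values are assumptions for illustration, not taken from this commit):

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
# import path assumed from the package layout implied by this diff
from colossalai.auto_parallel.tensor_shard.utils import tranpose_partition_dim

# a 2x2 logical mesh over 4 physical devices
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# a (4, 8, 16) tensor: dim 0 sharded over mesh dim 0, dim 2 sharded over mesh dim 1
spec = ShardingSpec(device_mesh, torch.Size((4, 8, 16)), {0: [0], 2: [1]})

# swap the partitions of dims 0 and 2; entire_shape is now transposed as well (the fix above)
tranpose_partition_dim(spec, 0, 2)
assert spec.entire_shape == torch.Size((16, 8, 4))
assert spec.dim_partition_dict == {0: [1], 2: [0]}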