Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-18 16:00:49 +00:00)
[autoparallel] fixed wrong generated strategy for dot op (#1746)
* [autoparallel] fixed wrong generated strategy for dot op
* polish code

@@ -5,13 +5,13 @@ from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
     generate_sharding_size,
-    switch_partition_dim,
+    tranpose_partition_dim,
     update_partition_dim,
 )
 
 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
-    'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]

@@ -36,9 +36,10 @@ def ignore_sharding_exception(func):
 def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
     """
     This function checks whether the ShardingSpec is valid for the physical tensor.
-    This check includes 2 items:
+    This check includes 3 items:
     1. the sharding spec covers all dimensions of the physical tensor
     2. the sharding spec for each dimension is divisible by the number of devices.
+    3. the sharding spec's entire shape must match the tensor shape
     #
     """
     # make sure all dims are covered in sharding spec

@@ -65,3 +66,6 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
         assert dim_size >= num_devices and dim_size % num_devices == 0, \
             f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
 
+    # make sure the entire shape matches the physical tensor shape
+    assert sharding_spec.entire_shape == tensor.shape
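
For reference, the three checks boil down to the sketch below. This is a minimal, self-contained illustration, not the library's implementation: the helper name, the 'S0'/'R' shorthand for sharded/replicated dimensions, and the plain-tuple mesh shape are stand-ins for the real ShardingSpec and DeviceMesh objects.

import torch

def toy_check_validity(sharding_sequence, entire_shape, mesh_shape, tensor):
    # 1. the spec must cover every dimension of the physical tensor
    assert len(sharding_sequence) == tensor.dim()
    for i, (spec, dim_size) in enumerate(zip(sharding_sequence, tensor.shape)):
        if spec == 'R':    # replicated dimension, nothing to divide
            continue
        # 'S0' shards over mesh axis 0, 'S01' over axes 0 and 1, etc.
        num_devices = 1
        for mesh_dim in spec[1:]:
            num_devices *= mesh_shape[int(mesh_dim)]
        # 2. each sharded dimension must be divisible by its device count
        assert dim_size >= num_devices and dim_size % num_devices == 0, \
            f'dim {i} has size {dim_size} but is sharded over {num_devices} devices'
    # 3. the spec's entire shape must match the tensor shape (the new check)
    assert entire_shape == tensor.shape

# a (4, 8) tensor: dim 0 split over mesh axis 0 of a 2x2 mesh, dim 1 replicated
toy_check_validity(['S0', 'R'], torch.Size([4, 8]), (2, 2), torch.randn(4, 8))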

@@ -8,12 +8,12 @@ import torch
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 __all__ = [
-    'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]
 
 
-def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
+def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
     """
     Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.

@@ -22,19 +22,26 @@ def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> S
         dim1 (int): the tensor dimension to switch
         dim2 (int): the tensor dimension to switch
     """
-    assert len(sharding_spec.entire_shape) == 2
+    assert len(sharding_spec.entire_shape) >= 2, \
+        'The entire_shape of the sharding spec must have at least 2 dimensions'
     dim_partition_dict = sharding_spec.dim_partition_dict
 
     # transpose the dim partition
     dim1_partition = dim_partition_dict.pop(dim1, None)
     dim2_partition = dim_partition_dict.pop(dim2, None)
 
     if dim1_partition:
         dim_partition_dict[dim2] = dim1_partition
 
     if dim2_partition:
         dim_partition_dict[dim1] = dim2_partition
 
+    # get the transposed shape
+    new_shape = list(sharding_spec.entire_shape[:])
+    new_shape[dim2], new_shape[dim1] = new_shape[dim1], new_shape[dim2]
+    new_shape = torch.Size(new_shape)
+
     # re-init the sharding spec
-    sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+    sharding_spec.__init__(sharding_spec.device_mesh, new_shape, dim_partition_dict)
     return sharding_spec
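
This last hunk is the actual fix: the old code re-initialized the spec with the untransposed entire_shape, so for non-square tensors the moved partition dims pointed at dimensions of the wrong size, which is presumably how the wrong dot-op strategies arose. Below is a toy reproduction of the corrected behaviour, with a hypothetical ToySpec standing in for the real ShardingSpec.

import torch

class ToySpec:
    """Hypothetical stand-in for ShardingSpec: just the fields the diff touches."""

    def __init__(self, device_mesh, entire_shape, dim_partition_dict):
        self.device_mesh = device_mesh
        self.entire_shape = entire_shape
        self.dim_partition_dict = dim_partition_dict

def toy_tranpose_partition_dim(spec, dim1, dim2):
    # swap which mesh axes shard dim1 and dim2
    dim1_partition = spec.dim_partition_dict.pop(dim1, None)
    dim2_partition = spec.dim_partition_dict.pop(dim2, None)
    if dim1_partition:
        spec.dim_partition_dict[dim2] = dim1_partition
    if dim2_partition:
        spec.dim_partition_dict[dim1] = dim2_partition
    # the fix: transpose the logical shape as well, so each partition
    # entry keeps pointing at a dimension of the right size
    new_shape = list(spec.entire_shape)
    new_shape[dim1], new_shape[dim2] = new_shape[dim2], new_shape[dim1]
    spec.__init__(spec.device_mesh, torch.Size(new_shape), spec.dim_partition_dict)
    return spec

spec = ToySpec(device_mesh=None, entire_shape=torch.Size([4, 8]),
               dim_partition_dict={0: [0]})     # mesh axis 0 shards tensor dim 0
toy_tranpose_partition_dim(spec, 0, 1)
assert spec.entire_shape == torch.Size([8, 4])  # shape transposed with the partition
assert spec.dim_partition_dict == {1: [0]}      # mesh axis 0 now shards tensor dim 1

With the old behaviour (entire_shape left as (4, 8)), the resulting spec would describe the transposed tensor under the original shape and fail exactly the entire_shape == tensor.shape assertion added to check_sharding_spec_validity in the first file.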