[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy

* polish code
Frank Lee authored on 2022-10-19 12:53:06 +08:00, committed by GitHub
parent cbe9a4cb45
commit eee84908d4
36 changed files with 459 additions and 303 deletions


@@ -1,10 +1,12 @@
-import torch
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 import operator
 from copy import deepcopy
 from enum import Enum
 from functools import reduce
+
+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)
 
 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
@@ -138,7 +140,19 @@ class _DimSpec:
         return difference
 
 
-class ShardingException(Exception):
+class ShardingSpecException(Exception):
     pass
+
+
+class ShardingOutOfIndexError(ShardingSpecException):
+    pass
+
+
+class DuplicatedShardingDimensionError(ShardingSpecException):
+    pass
+
+
+class ShardingNotDivisibleError(ShardingSpecException):
+    pass
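Because every new error derives from ShardingSpecException, a strategy generator can filter out an illegal sharding with one except clause instead of matching on a generic ValueError. A minimal sketch, assuming ShardingSpec.__init__ invokes _sanity_check (the build_spec_or_skip helper is hypothetical):

from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException


def build_spec_or_skip(device_mesh, entire_shape, dim_partition_dict):
    # Hypothetical helper: return None for an illegal strategy instead of
    # raising, so the caller can simply drop it from the candidate list.
    try:
        return ShardingSpec(device_mesh, entire_shape, dim_partition_dict=dim_partition_dict)
    except ShardingSpecException:
        return None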
@@ -156,7 +170,11 @@ class ShardingSpec:
         sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
     '''
 
-    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
+    def __init__(self,
+                 device_mesh: DeviceMesh,
+                 entire_shape: torch.Size,
+                 dim_partition_dict=None,
+                 sharding_sequence=None):
         self.device_mesh = device_mesh
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
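For reference, constructing a spec against the annotated signature might look like the sketch below; the 2 x 2 mesh, the tensor shape, and the exact DeviceMesh constructor arguments are illustrative assumptions, not taken from this commit:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# an illustrative 2 x 2 logical mesh over 4 devices
device_mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2))

# shard tensor dim 0 over mesh axis 0 and dim 2 over mesh axis 1: [S0, R, S1, R]
spec = ShardingSpec(device_mesh,
                    entire_shape=torch.Size((8, 8, 8, 8)),
                    dim_partition_dict={0: [0], 2: [1]})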
@@ -174,19 +192,36 @@ class ShardingSpec:
         return ' '.join(res_list)
 
     def _sanity_check(self):
-        '''
-        In sanity check, we need make sure all axes in logical device mesh only be used
-        once.
-        '''
-        dim_check_list = [i for i in range(self.device_mesh.logical_mesh_id.dim())]
+        # make sure all axes in logical device mesh only be used once
+        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
         for dim, shard_list in self.dim_partition_dict.items():
             for element in shard_list:
                 if element in dim_check_list:
                     dim_check_list.remove(element)
                 else:
-                    raise ValueError(
+                    raise DuplicatedShardingDimensionError(
                         f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
 
+        # make sure that the dimension is not out of index
+        for dim in self.dim_partition_dict.keys():
+            if dim >= len(self.entire_shape):
+                raise ShardingOutOfIndexError(
+                    f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
+                )
+
+        # make sure that the sharding for a dimension is divisible by the number of devices
+        for dim, shard_list in self.dim_partition_dict.items():
+            tensor_dim_size = self.entire_shape[dim]
+            num_devices = 1
+
+            for element in shard_list:
+                num_devices *= self.device_mesh.mesh_shape[element]
+
+            if tensor_dim_size % num_devices != 0:
+                raise ShardingNotDivisibleError(
+                    f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
+                )
 
     def convert_dict_to_shard_sequence(self):
         '''
         Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
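The three checks map one-to-one onto the new exception types. A rough demonstration of each failure mode, again assuming __init__ runs _sanity_check and reusing the illustrative 2 x 2 mesh from above:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import (DuplicatedShardingDimensionError, ShardingNotDivisibleError,
                                             ShardingOutOfIndexError, ShardingSpec)

device_mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2))

cases = [
    # mesh axis 0 is assigned to both tensor dims 0 and 1
    (torch.Size((8, 8)), {0: [0], 1: [0]}, DuplicatedShardingDimensionError),
    # the tensor only has 2 dims, so dim 5 is out of index
    (torch.Size((8, 8)), {5: [0]}, ShardingOutOfIndexError),
    # dim 0 has size 7, which cannot be split over the 2 devices on mesh axis 0
    (torch.Size((7, 8)), {0: [0]}, ShardingNotDivisibleError),
]

for entire_shape, dim_partition_dict, expected in cases:
    try:
        ShardingSpec(device_mesh, entire_shape, dim_partition_dict=dim_partition_dict)
    except expected as e:
        print(f'{expected.__name__}: {e}')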