mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] handled illegal sharding strategy (#1728)
* [autoparallel] handled illegal sharding strategy
* polish code
@@ -1,10 +1,12 @@
-import torch
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
+import operator
 from copy import deepcopy
 from enum import Enum
 from functools import reduce
-import operator
+
+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)

 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

@@ -138,7 +140,19 @@ class _DimSpec:
         return difference


-class ShardingException(Exception):
+class ShardingSpecException(Exception):
+    pass
+
+
+class ShardingOutOfIndexError(ShardingSpecException):
+    pass
+
+
+class DuplicatedShardingDimensionError(ShardingSpecException):
+    pass
+
+
+class ShardingNotDivisibleError(ShardingSpecException):
     pass

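The commit title says illegal sharding strategies are now handled; a minimal sketch of the caller side, assuming the consumer builds a ShardingSpec per candidate and that construction runs the sanity checks shown further down (the caller-side code is not part of this diff):

def filter_legal_sharding_specs(device_mesh, entire_shape, candidate_dim_partition_dicts):
    # Keep only the candidates that pass ShardingSpec's sanity checks;
    # catching the common base class covers ShardingOutOfIndexError,
    # DuplicatedShardingDimensionError and ShardingNotDivisibleError alike.
    legal_specs = []
    for dim_partition_dict in candidate_dim_partition_dicts:
        try:
            spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict=dim_partition_dict)
        except ShardingSpecException:
            continue
        legal_specs.append(spec)
    return legal_specs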
@@ -156,7 +170,11 @@ class ShardingSpec:
         sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
     '''

-    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
+    def __init__(self,
+                 device_mesh: DeviceMesh,
+                 entire_shape: torch.Size,
+                 dim_partition_dict=None,
+                 sharding_sequence=None):
         self.device_mesh = device_mesh
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
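With the annotated signature, construction reads roughly as below. The DeviceMesh constructor arguments (a physical mesh id tensor plus a mesh shape) and the module path are assumptions here, not something this diff shows:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)    # four devices viewed as a 2 x 2 logical mesh (assumed setup)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

entire_shape = torch.Size((16, 32))
# shard tensor dim 0 along mesh axis 0 and dim 1 along mesh axis 1, i.e. [S0, S1]
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict={0: [0], 1: [1]})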
@@ -174,19 +192,36 @@ class ShardingSpec:
         return ' '.join(res_list)

     def _sanity_check(self):
-        '''
-        In sanity check, we need make sure all axes in logical device mesh only be used
-        once.
-        '''
-        dim_check_list = [i for i in range(self.device_mesh.logical_mesh_id.dim())]
+        # make sure all axes in logical device mesh only be used once
+        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
         for dim, shard_list in self.dim_partition_dict.items():
             for element in shard_list:
                 if element in dim_check_list:
                     dim_check_list.remove(element)
                 else:
-                    raise ValueError(
+                    raise DuplicatedShardingDimensionError(
                         f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

+        # make sure that the dimension is not out of index
+        for dim in self.dim_partition_dict.keys():
+            if dim >= len(self.entire_shape):
+                raise ShardingOutOfIndexError(
+                    f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
+                )
+
+        # make sure that the sharding for a dimension is divisible by the number of devices
+        for dim, shard_list in self.dim_partition_dict.items():
+            tensor_dim_size = self.entire_shape[dim]
+            num_devices = 1
+
+            for element in shard_list:
+                num_devices *= self.device_mesh.mesh_shape[element]
+
+            if tensor_dim_size % num_devices != 0:
+                raise ShardingNotDivisibleError(
+                    f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
+                )
+
     def convert_dict_to_shard_sequence(self):
         '''
         Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
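For reference, each new exception corresponds to one misuse of dim_partition_dict. A hedged sketch, reusing the hypothetical 2 x 2 device_mesh from the construction example above and assuming __init__ invokes _sanity_check (that call sits outside this hunk):

illegal_cases = [
    # mesh axis 0 used twice for tensor dim 0 -> DuplicatedShardingDimensionError
    (torch.Size((16, 32)), {0: [0, 0]}),
    # 2-D tensor but dim_partition_dict shards dim 2 -> ShardingOutOfIndexError
    (torch.Size((16, 32)), {2: [0]}),
    # dim size 7 not divisible by the 2 devices on mesh axis 0 -> ShardingNotDivisibleError
    (torch.Size((7, 32)), {0: [0]}),
]

for entire_shape, dim_partition_dict in illegal_cases:
    try:
        ShardingSpec(device_mesh, entire_shape, dim_partition_dict=dim_partition_dict)
    except ShardingSpecException as e:
        print(type(e).__name__)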