mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[autoparallel] handled illegal strategy in node handler (#1743)
* [autoparallel] handled illegal strategy in node handler * polish code
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
|
||||
from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape
|
||||
from .factory import generate_resharding_costs, generate_sharding_spec
|
||||
from .misc import ignore_sharding_exception
|
||||
from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
|
||||
switch_partition_dim, update_partition_dim)
|
||||
from .misc import check_sharding_spec_validity, ignore_sharding_exception
|
||||
from .sharding import (
|
||||
enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
generate_sharding_size,
|
||||
switch_partition_dim,
|
||||
update_partition_dim,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
|
||||
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim',
|
||||
'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
|
||||
'generate_sharding_size'
|
||||
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
|
||||
'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
|
||||
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
|
||||
]
|
||||
|
@@ -1,7 +1,9 @@
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.sharding_spec import ShardingSpecException
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
|
||||
|
||||
__all__ = ['ignore_sharding_exception']
|
||||
|
||||
@@ -29,3 +31,37 @@ def ignore_sharding_exception(func):
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
|
||||
"""
|
||||
This function checks whether the ShardingSpec is valid for the physical tensor.
|
||||
This check includes 2 items:
|
||||
1. the sharding spec covers all dimensions of the physical tensor
|
||||
2. the sharding spec for each dimension is divisible by the number of devices.
|
||||
#
|
||||
"""
|
||||
# make sure all dims are covered in sharding spec
|
||||
sharding_len = len(sharding_spec.sharding_sequence)
|
||||
tensor_num_dim = tensor.dim()
|
||||
num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
|
||||
num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
|
||||
assert sharding_len == tensor_num_dim, \
|
||||
f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
|
||||
|
||||
# make sure the sharding is valid for each dim
|
||||
for i in range(tensor_num_dim):
|
||||
dim_size = tensor.shape[i]
|
||||
dim_spec = sharding_spec.sharding_sequence[i]
|
||||
|
||||
if str(dim_spec).startswith('S'):
|
||||
devices_str = str(dim_spec).lstrip('S')
|
||||
num_devices = 1
|
||||
|
||||
if '0' in devices_str:
|
||||
num_devices *= num_devices_in_col
|
||||
if '1' in devices_str:
|
||||
num_devices *= num_devices_in_row
|
||||
|
||||
assert dim_size >= num_devices and dim_size % num_devices == 0, \
|
||||
f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
|
||||
|
Reference in New Issue
Block a user