Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 13:00:52 +00:00
[autoparallel] refactored the autoparallel module for organization (#1706)
* [autoparallel] refactored the autoparallel module for organization
* polish code
12 colossalai/auto_parallel/tensor_shard/utils/__init__.py Normal file
@@ -0,0 +1,12 @@
from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import exception_handler
from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
                       switch_partition_dim, update_partition_dim)

__all__ = [
    'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
    'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
    'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
    'generate_sharding_size'
]
96 colossalai/auto_parallel/tensor_shard/utils/broadcast.py Normal file
@@ -0,0 +1,96 @@
import torch
from enum import Enum, auto
from typing import List
from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']


class BroadcastType(Enum):
    EQUAL = auto()
    PADDING = auto()
    MULTIPLE = auto()


def is_broadcastable(shape1: torch.Size, shape2: torch.Size) -> bool:
    """
    Check if two shapes are broadcastable to each other.
    """
    for s1, s2 in zip(shape1[::-1], shape2[::-1]):
        if s1 == 1 or s2 == 1 or s1 == s2:
            pass
        else:
            return False
    return True


def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
    """
    Compute the broadcast shape given two shapes.
    """
    assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
    shape1_reverse = shape1[::-1]
    shape2_reverse = shape2[::-1]
    min_common_dim = min(len(shape1), len(shape2))
    dims = []
    for s1, s2 in zip(shape1_reverse, shape2_reverse):
        dims.append(max(s1, s2))

    # append the remaining dims
    dims.extend(shape1_reverse[min_common_dim:])
    dims.extend(shape2_reverse[min_common_dim:])
    return dims[::-1]


def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
                                              physical_shape: torch.Size) -> ShardingSpec:
    """
    This function computes the sharding spec for the physical shape of a broadcast tensor.

    Args:
        logical_sharding_spec (ShardingSpec): the sharding spec of the tensor after broadcasting
        logical_shape (torch.Size): the logical shape, i.e. the broadcast shape of the tensor
        physical_shape (torch.Size): the shape of the tensor before broadcasting
    """
    # get the number of dimensions
    logical_num_dims = len(logical_shape)
    physical_num_dims = len(physical_shape)

    # track each dim and its broadcasting type
    logical_dim_broadcast_info = {}

    for i in range(logical_num_dims):
        # get the trailing dim size
        logical_dim_idx = logical_num_dims - i - 1
        physical_dim_idx = physical_num_dims - i - 1
        logical_dim_size = logical_shape[logical_dim_idx]

        if physical_dim_idx >= 0:
            physical_dim_size = physical_shape[physical_dim_idx]

            if physical_dim_size == logical_dim_size:
                logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
            elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
                logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
        else:
            logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDING

    # generate the sharding spec for the physical shape
    physical_dim_partition = {}
    logical_dim_partition = logical_sharding_spec.dim_partition_dict

    for shape_dim, mesh_dim in logical_dim_partition.items():
        logical_broadcast_type = logical_dim_broadcast_info[shape_dim]

        if logical_broadcast_type == BroadcastType.PADDING or logical_broadcast_type == BroadcastType.MULTIPLE:
            pass
        else:
            # get the corresponding physical dim
            physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
            physical_dim_partition[physical_dim] = mesh_dim

    physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
                                          entire_shape=physical_shape,
                                          dim_partition_dict=physical_dim_partition)

    return physical_sharding_spec
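For illustration, the two shape helpers above can be exercised on their own; the shapes below are arbitrary examples, not part of this commit:

    import torch
    from colossalai.auto_parallel.tensor_shard.utils import get_broadcast_shape, is_broadcastable

    shape_a = torch.Size([4, 1, 8])
    shape_b = torch.Size([2, 8])

    # broadcasting is checked from the trailing dimensions backwards
    print(is_broadcastable(shape_a, shape_b))     # True
    print(get_broadcast_shape(shape_a, shape_b))  # [4, 2, 8]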
90 colossalai/auto_parallel/tensor_shard/utils/factory.py Normal file
@@ -0,0 +1,90 @@
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx.node import Node

from ..constants import INFINITY_COST

__all__ = ['generate_sharding_spec', 'generate_resharding_costs']


def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
                           dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
    """
    Generate the sharding spec of the tensor based on the given dim_partition_dict.

    Args:
        input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is given, its _meta_data attribute is used as the meta tensor.
        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
        dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the list of mesh dimensions used for sharding.
    """

    if isinstance(input_, Node):
        assert hasattr(input_, '_meta_data'), 'The given node has no attribute _meta_data'
        meta_tensor = input_._meta_data
        assert meta_tensor is not None, "The given node's _meta_data attribute is None"
        shape = meta_tensor.shape
    elif isinstance(input_, torch.Tensor):
        shape = input_.shape
    else:
        raise TypeError(
            f'We cannot generate a sharding spec for {type(input_)}, only torch.fx.Node or torch.Tensor is expected.'
        )
    for dim_index, sharding_index_list in dim_partition_dict.items():
        sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
        sharding_size = reduce(operator.mul, sharding_list, 1)
        assert shape[dim_index] % sharding_size == 0, \
            f'we cannot shard dimension {dim_index} of the tensor into {sharding_size} partitions.'

    sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
    return sharding_spec


def generate_resharding_costs(nodes: List[Node],
                              sharding_specs: List[ShardingSpec],
                              count_backward: Optional[bool] = True,
                              dtype: Optional[torch.dtype] = None,
                              index=None):
    '''
    Compute the resharding costs with this specific strategy.

    Args:
        nodes (List[Node]): a list of nodes
        sharding_specs (List[ShardingSpec]): a list of ShardingSpec for the nodes.
        count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
        dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
        index (Optional[int]): the index into a strategy's output sharding specs when a node produces a list of sharding specs.
    '''
    # The resharding cost of a weight is counted as well because of weight-sharing cases.
    resharding_costs = {}
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

    # shape consistency manager is a singleton class
    shape_consistency_manager = ShapeConsistencyManager()

    for input_node, input_spec in zip(nodes, sharding_specs):
        resharding_costs[input_node] = []
        for strategy in input_node.strategies_vector:
            input_sharding_spec = strategy.output_sharding_spec
            if not isinstance(input_sharding_spec, ShardingSpec):
                assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
                input_sharding_spec = input_sharding_spec[index]
            assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should NOT be a tuple of tensors.'
            try:
                # compute the resharding cost
                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)

                # multiply by the element size of the dtype to get the communication cost in bytes
                resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
            except AssertionError as e:
                warnings.warn(f'{e}')
                resharding_cost = INFINITY_COST
            resharding_costs[input_node].append(resharding_cost)
    return resharding_costs
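As a rough usage sketch (not part of this commit), generate_sharding_spec can be driven by a plain tensor and a dim_partition_dict. The DeviceMesh construction below is an assumption based on common usage elsewhere in the project and may differ from your setup:

    import torch
    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.auto_parallel.tensor_shard.utils import generate_sharding_spec

    # assumed: a 2 x 2 logical mesh over 4 devices
    physical_mesh_id = torch.arange(0, 4)
    device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

    tensor = torch.empty(8, 16)
    # shard tensor dim 0 along mesh dim 0 and tensor dim 1 along mesh dim 1
    spec = generate_sharding_spec(tensor, device_mesh, dim_partition_dict={0: [0], 1: [1]})
    print(spec.sharding_sequence)  # expected to read roughly [S0, S1]

The divisibility assert in generate_sharding_spec would reject, for example, sharding a dimension of size 7 across 2 devices.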
26 colossalai/auto_parallel/tensor_shard/utils/misc.py Normal file
@@ -0,0 +1,26 @@
import functools
import warnings

__all__ = ['exception_handler']


def exception_handler(func):
    """
    A decorator that catches any AssertionError raised inside the wrapped function and converts it into a warning.

    Usage:
        # mute the assertion error in the function
        @exception_handler
        def do_something():
            ...
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            rst = func(*args, **kwargs)
            return rst
        except AssertionError as e:
            warnings.warn(f'{e}')

    return wrapper
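A minimal sketch of how the decorator is meant to be used; check_divisible is a hypothetical helper, not part of this commit:

    from colossalai.auto_parallel.tensor_shard.utils import exception_handler

    @exception_handler
    def check_divisible(dim_size: int, shard_size: int):
        # the assert is swallowed by the decorator and surfaced as a warning
        assert dim_size % shard_size == 0, f'{dim_size} is not divisible by {shard_size}'
        return dim_size // shard_size

    print(check_divisible(8, 2))  # 4
    print(check_divisible(8, 3))  # warns and returns None instead of raising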
113 colossalai/auto_parallel/tensor_shard/utils/sharding.py Normal file
@@ -0,0 +1,113 @@
import operator
from copy import deepcopy
from functools import reduce
from typing import Dict

import torch

from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = [
    'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
    'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
]


def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
    """
    Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.

    Args:
        sharding_spec (ShardingSpec): the sharding spec whose partition dims are switched
        dim1 (int): the first tensor dimension to switch
        dim2 (int): the second tensor dimension to switch
    """
    assert len(sharding_spec.entire_shape) == 2
    dim_partition_dict = sharding_spec.dim_partition_dict
    dim1_partition = dim_partition_dict.pop(dim1, None)
    dim2_partition = dim_partition_dict.pop(dim2, None)

    if dim1_partition:
        dim_partition_dict[dim2] = dim1_partition

    if dim2_partition:
        dim_partition_dict[dim1] = dim2_partition

    # re-init the sharding spec
    sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
    return sharding_spec


def update_partition_dim(sharding_spec: ShardingSpec,
                         dim_mapping: Dict[int, int],
                         physical_shape: torch.Size,
                         inplace: bool = False):
    """
    This method is used to update the partition dim dict from the logical one to the physical one.

    Args:
        sharding_spec (ShardingSpec): the sharding spec whose partition dims are updated
        dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor dimension
        physical_shape (torch.Size): the physical shape of the tensor
        inplace (bool): whether to modify the given sharding spec in place, defaults to False
    """

    if inplace:
        current_sharding_spec = sharding_spec
    else:
        current_sharding_spec = deepcopy(sharding_spec)

    old_dim_partition_dict = current_sharding_spec.dim_partition_dict
    new_dim_partition_dict = {}

    # assign new dims
    for old_dim, new_dim in dim_mapping.items():
        mesh_dims = old_dim_partition_dict.pop(old_dim)
        new_dim_partition_dict[new_dim] = mesh_dims

    for tensor_dim, mesh_dims in old_dim_partition_dict.items():
        if tensor_dim in new_dim_partition_dict:
            raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}")
        else:
            new_dim_partition_dict[tensor_dim] = mesh_dims

    # update the sharding spec
    current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
                                   entire_shape=physical_shape,
                                   dim_partition_dict=new_dim_partition_dict)
    return current_sharding_spec


def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
    dim_partition_list = []
    # enumerate all the 2D sharding cases
    for i in range(dim_size):
        for j in range(i + 1, dim_size):
            dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
            dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
            dim_partition_list.append(dim_partition_dict_0)
            dim_partition_list.append(dim_partition_dict_1)
    for i in range(dim_size):
        dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
        dim_partition_list.append(dim_partition_dict_flatten)

    return dim_partition_list


def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
    dim_partition_list = []
    # enumerate all the 1D sharding cases
    for i in range(dim_size):
        dim_partition_dict_0 = {i: [mesh_dim_0]}
        dim_partition_list.append(dim_partition_dict_0)

    return dim_partition_list


def generate_sharding_size(dim_partition_dict, device_mesh):
    total_sharding_size = 1
    for mesh_dim_list in dim_partition_dict.values():
        mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
        sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
        total_sharding_size *= sharding_size

    return total_sharding_size
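The enumeration helpers are pure functions of the mesh dims and the tensor rank, so their output can be checked directly; the values below follow from the code above:

    from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
                                                             enumerate_all_possible_2d_sharding)

    # every way to shard a rank-2 tensor over mesh dims 0 and 1
    print(enumerate_all_possible_2d_sharding(mesh_dim_0=0, mesh_dim_1=1, dim_size=2))
    # [{0: [0], 1: [1]}, {0: [1], 1: [0]}, {0: [0, 1]}, {1: [0, 1]}]

    # every way to shard a rank-2 tensor over a single mesh dim
    print(enumerate_all_possible_1d_sharding(mesh_dim_0=0, dim_size=2))
    # [{0: [0]}, {1: [0]}]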