mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[autoparallel] add split handler (#2032)
* [autoparallel] add split handler * add numerical test and runtime passes
This commit is contained in:
@@ -13,6 +13,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.comm_spec import CommSpec
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
@@ -27,6 +28,23 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
|
||||
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
|
||||
def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
|
||||
user_node_index: int):
|
||||
"""
|
||||
This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
|
||||
is converted into the user node expected form.
|
||||
"""
|
||||
rst = []
|
||||
for index, (origin_sharding_spec,
|
||||
target_sharding_spec) in enumerate(zip(origin_dict[node_index],
|
||||
input_dict[node_index][user_node_index])):
|
||||
rst.append(
|
||||
shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
|
||||
target_sharding_spec))
|
||||
rst = type(node)(rst)
|
||||
return rst
|
||||
|
||||
|
||||
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
|
||||
"""
|
||||
This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
|
||||
@@ -81,13 +99,34 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
|
||||
continue
|
||||
|
||||
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
|
||||
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
|
||||
continue
|
||||
with mod_graph.inserting_before(user_node):
|
||||
shape_consistency_node = mod_graph.create_node('call_function',
|
||||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
if isinstance(node.sharding_spec, (list, tuple)):
|
||||
assert isinstance(
|
||||
node.target_sharding_specs,
|
||||
(list,
|
||||
tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
|
||||
total_difference = 0
|
||||
for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
|
||||
node.target_sharding_specs[user_node_index]):
|
||||
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
|
||||
if total_difference == 0:
|
||||
continue
|
||||
with mod_graph.inserting_before(user_node):
|
||||
shape_consistency_node = mod_graph.create_node('call_function',
|
||||
runtime_apply_for_iterable_object,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
|
||||
else:
|
||||
assert isinstance(node.sharding_spec,
|
||||
ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
|
||||
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
|
||||
continue
|
||||
with mod_graph.inserting_before(user_node):
|
||||
shape_consistency_node = mod_graph.create_node('call_function',
|
||||
runtime_apply,
|
||||
args=(node, origin_dict_node, input_dict_node,
|
||||
node_to_index_dict[node], user_node_index))
|
||||
|
||||
new_args = list(user_node.args)
|
||||
new_kwargs = dict(user_node.kwargs)
|
||||
# the origin node may be a positional argument or key word argument of user node
|
||||
|
@@ -100,8 +100,24 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
||||
# skip the placeholder node added in _solution_annotation pass
|
||||
if not hasattr(node, 'sharding_spec'):
|
||||
continue
|
||||
output_dim_partition_dict = node.sharding_spec.dim_partition_dict
|
||||
device_mesh = node.sharding_spec.device_mesh
|
||||
|
||||
def _process_sharding_spec(sharding_spec):
|
||||
if isinstance(sharding_spec, ShardingSpec):
|
||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||
device_mesh = sharding_spec.device_mesh
|
||||
return dim_partition_dict, device_mesh
|
||||
if sharding_spec is None:
|
||||
return None, None
|
||||
assert isinstance(sharding_spec,
|
||||
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
|
||||
|
||||
device_mesh = sharding_spec[0].device_mesh
|
||||
dim_partition_dict = []
|
||||
for element in sharding_spec:
|
||||
dim_partition_dict.append(_process_sharding_spec(element))
|
||||
return dim_partition_dict, sharding_spec
|
||||
|
||||
output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
|
||||
new_args = []
|
||||
|
||||
if node.op == 'call_method':
|
||||
|
@@ -1,8 +1,10 @@
|
||||
from .permute_handler import PermuteHandler
|
||||
from .reshape_generator import PermuteGenerator, TransposeGenerator, ViewGenerator
|
||||
from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator
|
||||
from .split_handler import SplitHandler
|
||||
from .transpose_handler import TransposeHandler
|
||||
from .view_handler import ViewHandler
|
||||
|
||||
__all__ = [
|
||||
'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator'
|
||||
'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator',
|
||||
'SplitHandler', 'SplitGenerator'
|
||||
]
|
||||
|
@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator']
|
||||
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
|
||||
|
||||
|
||||
class ReshapeGenerator(FollowingStrategyGenerator):
|
||||
@@ -227,3 +227,73 @@ class TransposeGenerator(ReshapeGenerator):
|
||||
strategy_list.append(strategy)
|
||||
|
||||
return strategy_list
|
||||
|
||||
|
||||
class SplitGenerator(ReshapeGenerator):
|
||||
"""
|
||||
SplitGenerator deals with the sharding strategies of split op.
|
||||
"""
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
recover_dims = None
|
||||
dim_partition_dict_mapping = {}
|
||||
communication_action_mapping = {}
|
||||
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
|
||||
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
|
||||
split_size, split_dim = self.op_data['split_info'].data
|
||||
|
||||
if split_dim in dim_partition_dict_for_input:
|
||||
recover_dims = dim_partition_dict_for_input.pop(split_dim)
|
||||
|
||||
dim_partition_dict_for_output = [
|
||||
copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
|
||||
]
|
||||
assert len(dim_partition_dict_for_output) >= 2
|
||||
dim_partition_dict_mapping = {
|
||||
"input": dim_partition_dict_for_input,
|
||||
"output": dim_partition_dict_for_output,
|
||||
}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
# add index into name to pass the duplicated check
|
||||
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||
name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
|
||||
|
||||
# add comm action if the input need to be recovered to replica in the split dimension.
|
||||
if recover_dims:
|
||||
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
|
||||
if len(recover_dims) == 1:
|
||||
recover_dims = recover_dims[0]
|
||||
input_comm_action = self.get_communication_action(
|
||||
sharding_spec=sharding_spec_mapping["input"],
|
||||
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
logical_process_axis=recover_dims,
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=0)
|
||||
# it will gather the input through gather_dim during forward phase.
|
||||
input_comm_action.comm_spec.gather_dim = split_dim
|
||||
# it will split the input activation grad through split_dim during backward phase.
|
||||
input_comm_action.comm_spec.shard_dim = split_dim
|
||||
|
||||
elif len(recover_dims) >= 2:
|
||||
# original sharding spec
|
||||
source_spec = input_sharding_spec
|
||||
# target sharding spec
|
||||
target_spec = sharding_spec_mapping["input"]
|
||||
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
|
||||
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
|
||||
|
||||
else:
|
||||
input_comm_action = None
|
||||
|
||||
if input_comm_action is not None:
|
||||
communication_action_mapping["input"] = input_comm_action
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
strategy_list.append(strategy)
|
||||
|
||||
return strategy_list
|
||||
|
@@ -0,0 +1,63 @@
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ...sharding_strategy import OperationData, OperationDataType
|
||||
from ..node_handler import NodeHandler
|
||||
from ..registry import operator_registry
|
||||
from ..strategy import StrategyGenerator
|
||||
from .reshape_generator import SplitGenerator
|
||||
|
||||
__all__ = ['SplitHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.Tensor.split)
|
||||
@operator_registry.register(torch.split)
|
||||
class SplitHandler(NodeHandler):
|
||||
"""
|
||||
A SplitHandler which deals with the sharding strategies for torch.permute or torch.split.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# check if the input operand is a parameter
|
||||
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
|
||||
input_data = self.node.args[0]._meta_data
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
|
||||
split_size = self.node.args[1]
|
||||
if len(self.node.args) == 3:
|
||||
# (input, split_size, split_dim)
|
||||
split_dim = self.node.args[2]
|
||||
else:
|
||||
if self.node.kwargs:
|
||||
split_dim = self.node.kwargs['dim']
|
||||
else:
|
||||
split_dim = 0
|
||||
|
||||
num_dims = self.node.args[0]._meta_data.dim()
|
||||
# recover negative value to positive
|
||||
if split_dim < 0:
|
||||
split_dim += num_dims
|
||||
|
||||
split_info = (split_size, split_dim)
|
||||
physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
|
||||
|
||||
output_data = self.node._meta_data
|
||||
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
|
||||
|
||||
mapping = {
|
||||
"input": physical_input_operand,
|
||||
"split_info": physical_shape_operand,
|
||||
"output": physical_output_operand
|
||||
}
|
||||
|
||||
return mapping
|
@@ -10,8 +10,6 @@ from .strategy import ReshapeGenerator, StrategyGenerator
|
||||
__all__ = ['ReshapeHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.Tensor.split)
|
||||
@operator_registry.register(torch.split)
|
||||
@operator_registry.register(torch.flatten)
|
||||
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
|
||||
class ReshapeHandler(NodeHandler):
|
||||
|
@@ -49,12 +49,23 @@ class OutputGenerator(OutputStrategyGenerator):
|
||||
"""
|
||||
Generate replica strategy for output node.
|
||||
"""
|
||||
dim_partition_dict_mapping = {
|
||||
"output": {},
|
||||
}
|
||||
dim_partition_dict_mapping = {}
|
||||
dim_partition_dict_for_output = []
|
||||
for index, _ in enumerate(self.predecessor_nodes):
|
||||
mapping_name = f"input_{index}"
|
||||
dim_partition_dict_mapping[mapping_name] = {}
|
||||
if isinstance(self.op_data[mapping_name].data, (tuple, list)):
|
||||
dim_partition_dict_for_input = [{} for _ in range(len(self.op_data[mapping_name].data))]
|
||||
else:
|
||||
dim_partition_dict_for_input = {}
|
||||
dim_partition_dict_mapping[mapping_name] = dim_partition_dict_for_input
|
||||
dim_partition_dict_for_output.append(dim_partition_dict_for_input)
|
||||
|
||||
if len(dim_partition_dict_for_output) == 1:
|
||||
dim_partition_dict_for_output = dim_partition_dict_for_output[0]
|
||||
else:
|
||||
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
|
||||
|
||||
dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
|
||||
|
||||
communication_action_mapping = {}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
|
Reference in New Issue
Block a user