[autoparallel] add split handler (#2032)

* [autoparallel] add split handler

* add numerical test and runtime passes
YuliangLiu0306
2022-11-29 11:03:51 +08:00
committed by GitHub
parent 28aa9a4294
commit 0dbcd4a6f5
9 changed files with 500 additions and 22 deletions


@@ -13,6 +13,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
@@ -27,6 +28,23 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
    return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)


def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
                                      user_node_index: int):
    """
    This method will be invoked during runtime to do the shape consistency, which makes sure activations of type
    tuple or list are converted into the form expected by the user node.
    """
    rst = []
    for index, (origin_sharding_spec,
                target_sharding_spec) in enumerate(zip(origin_dict[node_index],
                                                       input_dict[node_index][user_node_index])):
        rst.append(
            shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
                                                                     target_sharding_spec))
    rst = type(node)(rst)
    return rst


def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
    """
    This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
    """
@@ -81,13 +99,34 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
            continue

        for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
            if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
                continue
            with mod_graph.inserting_before(user_node):
                shape_consistency_node = mod_graph.create_node('call_function',
                                                               runtime_apply,
                                                               args=(node, origin_dict_node, input_dict_node,
                                                                     node_to_index_dict[node], user_node_index))
            if isinstance(node.sharding_spec, (list, tuple)):
                assert isinstance(
                    node.target_sharding_specs,
                    (list,
                     tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
                total_difference = 0
                for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
                                                               node.target_sharding_specs[user_node_index]):
                    total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
                if total_difference == 0:
                    continue
                with mod_graph.inserting_before(user_node):
                    shape_consistency_node = mod_graph.create_node('call_function',
                                                                   runtime_apply_for_iterable_object,
                                                                   args=(node, origin_dict_node, input_dict_node,
                                                                         node_to_index_dict[node], user_node_index))
            else:
                assert isinstance(node.sharding_spec,
                                  ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
                if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
                    continue
                with mod_graph.inserting_before(user_node):
                    shape_consistency_node = mod_graph.create_node('call_function',
                                                                   runtime_apply,
                                                                   args=(node, origin_dict_node, input_dict_node,
                                                                         node_to_index_dict[node], user_node_index))

            new_args = list(user_node.args)
            new_kwargs = dict(user_node.kwargs)
            # the origin node may be a positional argument or key word argument of user node

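A minimal, self-contained sketch of the decision the pass makes above for tuple/list outputs: it sums the per-element sharding-sequence differences and only inserts a runtime_apply_for_iterable_object node when the total is non-zero. The FakeSpec class and its string-based difference metric are illustrative stand-ins, not colossalai APIs.

from dataclasses import dataclass

@dataclass
class FakeSpec:
    # stand-in for ShardingSpec, holding only a sharding sequence such as 'S0 R'
    sharding_sequence: str

    def sharding_sequence_difference(self, other: "FakeSpec") -> int:
        # simplified metric: count positions where the two sequences disagree
        return sum(a != b for a, b in zip(self.sharding_sequence.split(),
                                          other.sharding_sequence.split()))

# specs produced for the two chunks of a split node vs. the specs its user expects
produced = (FakeSpec('S0 R'), FakeSpec('S0 R'))
expected = (FakeSpec('R R'), FakeSpec('S0 R'))

total_difference = sum(p.sharding_sequence_difference(e) for p, e in zip(produced, expected))
needs_conversion = total_difference != 0    # True -> a conversion node would be inserted
print(total_difference, needs_conversion)   # 1 True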

@@ -100,8 +100,24 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
        # skip the placeholder node added in _solution_annotation pass
        if not hasattr(node, 'sharding_spec'):
            continue
        output_dim_partition_dict = node.sharding_spec.dim_partition_dict
        device_mesh = node.sharding_spec.device_mesh

        def _process_sharding_spec(sharding_spec):
            if isinstance(sharding_spec, ShardingSpec):
                dim_partition_dict = sharding_spec.dim_partition_dict
                device_mesh = sharding_spec.device_mesh
                return dim_partition_dict, device_mesh
            if sharding_spec is None:
                return None, None
            assert isinstance(sharding_spec,
                              (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
            device_mesh = sharding_spec[0].device_mesh
            dim_partition_dict = []
            for element in sharding_spec:
                dim_partition_dict.append(_process_sharding_spec(element))
            return dim_partition_dict, sharding_spec

        output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
        new_args = []

        if node.op == 'call_method':

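As a rough illustration of what the recursive helper above extracts, the sketch below uses a hypothetical stand-in dataclass instead of colossalai's ShardingSpec: a single spec yields its dim_partition_dict and device mesh, while a tuple of specs (such as the output of torch.split) yields a per-element list.

from dataclasses import dataclass

@dataclass
class FakeShardingSpec:
    # stand-in for ShardingSpec with just the two fields the pass reads
    dim_partition_dict: dict
    device_mesh: str = 'mesh_0'

def process_sharding_spec(sharding_spec):
    # mirrors the structure of _process_sharding_spec in the pass above
    if isinstance(sharding_spec, FakeShardingSpec):
        return sharding_spec.dim_partition_dict, sharding_spec.device_mesh
    if sharding_spec is None:
        return None, None
    assert isinstance(sharding_spec, (tuple, list))
    return [process_sharding_spec(element) for element in sharding_spec], sharding_spec

single = FakeShardingSpec({0: [0]})
nested = (FakeShardingSpec({1: [0]}), FakeShardingSpec({}))

print(process_sharding_spec(single))       # ({0: [0]}, 'mesh_0')
print(process_sharding_spec(nested)[0])    # [({1: [0]}, 'mesh_0'), ({}, 'mesh_0')]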

@@ -1,8 +1,10 @@
from .permute_handler import PermuteHandler
from .reshape_generator import PermuteGenerator, TransposeGenerator, ViewGenerator
from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator
from .split_handler import SplitHandler
from .transpose_handler import TransposeHandler
from .view_handler import ViewHandler
__all__ = [
    'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator'
    'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator',
    'SplitHandler', 'SplitGenerator'
]


@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator']
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
@@ -227,3 +227,73 @@ class TransposeGenerator(ReshapeGenerator):
            strategy_list.append(strategy)

        return strategy_list


class SplitGenerator(ReshapeGenerator):
    """
    SplitGenerator deals with the sharding strategies of the split op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            recover_dims = None
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
            dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
            split_size, split_dim = self.op_data['split_info'].data
            if split_dim in dim_partition_dict_for_input:
                recover_dims = dim_partition_dict_for_input.pop(split_dim)

            dim_partition_dict_for_output = [
                copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
            ]
            assert len(dim_partition_dict_for_output) >= 2
            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
            # add the index to the name to pass the duplication check
            # we keep the same strategies under different names for node merging; this will not increase the search space,
            # because the solver merges this node into other nodes and will not create a new variable for it.
            name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'

            # add a comm action if the input needs to be recovered to replica in the split dimension.
            if recover_dims:
                # if there is only one sharding dimension, we should use the value instead of a list as logical_process_axis.
                if len(recover_dims) == 1:
                    recover_dims = recover_dims[0]
                    input_comm_action = self.get_communication_action(
                        sharding_spec=sharding_spec_mapping["input"],
                        communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                        logical_process_axis=recover_dims,
                        comm_type=CommType.BEFORE,
                        arg_index=0)
                    # it will gather the input through gather_dim during the forward phase.
                    input_comm_action.comm_spec.gather_dim = split_dim
                    # it will split the input activation grad through split_dim during the backward phase.
                    input_comm_action.comm_spec.shard_dim = split_dim

                elif len(recover_dims) >= 2:
                    # original sharding spec
                    source_spec = input_sharding_spec
                    # target sharding spec
                    target_spec = sharding_spec_mapping["input"]
                    comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
                    input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)

                else:
                    input_comm_action = None

                if input_comm_action is not None:
                    communication_action_mapping["input"] = input_comm_action

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list

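The GATHER_FWD_SPLIT_BWD action above exists because a tensor that is sharded along the split dimension cannot simply be split locally. A toy single-process check of that rule, using torch.chunk to stand in for the column shards that would live on two devices:

import torch

full = torch.arange(16).reshape(2, 8)
shards = torch.chunk(full, 2, dim=1)            # pretend each [2, 4] half lives on one device

# reference: split the replicated tensor into chunks of 2 columns
ref_chunks = torch.split(full, 2, dim=1)

# splitting each shard locally produces pieces from the wrong columns:
# 'chunk 0' on the second device holds columns 4-5, which are not part of the global chunk 0
local_chunks = [torch.split(shard, 2, dim=1) for shard in shards]
assert torch.equal(local_chunks[1][0], full[:, 4:6])
assert torch.equal(ref_chunks[0], full[:, 0:2])

# gathering first (the all-gather inserted at runtime) and then splitting is correct
gathered = torch.cat(shards, dim=1)
assert all(torch.equal(a, b) for a, b in zip(torch.split(gathered, 2, dim=1), ref_chunks))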

@@ -0,0 +1,63 @@
from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import SplitGenerator

__all__ = ['SplitHandler']


@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
class SplitHandler(NodeHandler):
    """
    A SplitHandler which deals with the sharding strategies for torch.split and torch.Tensor.split.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
        split_size = self.node.args[1]
        if len(self.node.args) == 3:
            # (input, split_size, split_dim)
            split_dim = self.node.args[2]
        else:
            if self.node.kwargs:
                split_dim = self.node.kwargs['dim']
            else:
                split_dim = 0
        num_dims = self.node.args[0]._meta_data.dim()
        # convert a negative split dim to its positive counterpart
        if split_dim < 0:
            split_dim += num_dims
        split_info = (split_size, split_dim)
        physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
        mapping = {
            "input": physical_input_operand,
            "split_info": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping

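For context, the handler reads the split size from args[1] and the dimension from args[2] or kwargs['dim'] (defaulting to 0 and normalizing negative values). The sketch below shows the kind of traced node it receives; depending on the torch version the op may be recorded as torch.split (call_function) or Tensor.split (call_method), which is why both are registered above.

import torch
from torch.fx import symbolic_trace

class SplitModel(torch.nn.Module):
    def forward(self, x):
        # split the last dimension into chunks of size 2
        return torch.split(x, 2, dim=-1)

gm = symbolic_trace(SplitModel())
for node in gm.graph.nodes:
    if node.op in ('call_function', 'call_method'):
        # e.g. target=torch.split, args=(x, 2), kwargs={'dim': -1}
        print(node.op, node.target, node.args, node.kwargs)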

@@ -10,8 +10,6 @@ from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']
@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):


@@ -49,12 +49,23 @@ class OutputGenerator(OutputStrategyGenerator):
"""
Generate replica strategy for output node.
"""
dim_partition_dict_mapping = {
"output": {},
}
dim_partition_dict_mapping = {}
dim_partition_dict_for_output = []
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
dim_partition_dict_mapping[mapping_name] = {}
if isinstance(self.op_data[mapping_name].data, (tuple, list)):
dim_partition_dict_for_input = [{} for _ in range(len(self.op_data[mapping_name].data))]
else:
dim_partition_dict_for_input = {}
dim_partition_dict_mapping[mapping_name] = dim_partition_dict_for_input
dim_partition_dict_for_output.append(dim_partition_dict_for_input)
if len(dim_partition_dict_for_output) == 1:
dim_partition_dict_for_output = dim_partition_dict_for_output[0]
else:
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
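A toy, plain-Python illustration (hypothetical data, not colossalai objects) of the replica mapping this now builds when one predecessor of the output node yields a single tensor and another yields the tuple produced by a split:

# pretend data held by the output node's predecessors
op_data = {"input_0": "tensor", "input_1": ("chunk0", "chunk1")}

dim_partition_dict_mapping = {}
dim_partition_dict_for_output = []
for name, data in op_data.items():
    # tuple/list inputs get one empty (replica) partition dict per element
    entry = [{} for _ in data] if isinstance(data, (tuple, list)) else {}
    dim_partition_dict_mapping[name] = entry
    dim_partition_dict_for_output.append(entry)

dim_partition_dict_mapping["output"] = tuple(dim_partition_dict_for_output)
print(dim_partition_dict_mapping)
# {'input_0': {}, 'input_1': [{}, {}], 'output': ({}, [{}, {}])}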