From 37df666f38efad28b4cb681e8278c0deadc8679c Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Wed, 8 Feb 2023 15:02:49 +0800 Subject: [PATCH] [autoparallel] refactor handlers which reshape input tensors (#2615) * [autoparallel] refactor handlers which reshape input tensors * polish --- .../tensor_shard/node_handler/__init__.py | 12 +- ..._handler.py => default_reshape_handler.py} | 10 +- .../node_handler/experimental/__init__.py | 10 - .../experimental/reshape_generator.py | 299 ------------------ .../{experimental => }/permute_handler.py | 9 +- .../{experimental => }/split_handler.py | 9 +- .../node_handler/strategy/__init__.py | 15 +- .../strategy/reshape_generator.py | 267 +++++++++++++++- .../{experimental => }/transpose_handler.py | 9 +- .../{experimental => }/view_handler.py | 9 +- ...ler.py => test_default_reshape_handler.py} | 8 +- .../test_node_handler/test_getitem_handler.py | 6 +- .../test_permute_and_transpose_handler.py | 2 +- .../test_node_handler/test_split_handler.py | 5 +- .../test_node_handler/test_view_handler.py | 2 +- 15 files changed, 307 insertions(+), 365 deletions(-) rename colossalai/auto_parallel/tensor_shard/node_handler/{reshape_handler.py => default_reshape_handler.py} (87%) delete mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py delete mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py rename colossalai/auto_parallel/tensor_shard/node_handler/{experimental => }/permute_handler.py (92%) rename colossalai/auto_parallel/tensor_shard/node_handler/{experimental => }/split_handler.py (89%) rename colossalai/auto_parallel/tensor_shard/node_handler/{experimental => }/transpose_handler.py (90%) rename colossalai/auto_parallel/tensor_shard/node_handler/{experimental => }/view_handler.py (88%) rename tests/test_auto_parallel/test_tensor_shard/test_node_handler/{test_reshape_handler.py => test_default_reshape_handler.py} (91%) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 87bd8966b..0050358ce 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -3,8 +3,8 @@ from .batch_norm_handler import BatchNormModuleHandler from .binary_elementwise_handler import BinaryElementwiseHandler from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler from .conv_handler import ConvFunctionHandler, ConvModuleHandler +from .default_reshape_handler import DefaultReshapeHandler from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler -from .experimental import PermuteHandler, ViewHandler from .getattr_handler import GetattrHandler from .getitem_handler import GetItemHandler from .layer_norm_handler import LayerNormModuleHandler @@ -13,20 +13,24 @@ from .matmul_handler import MatMulHandler from .normal_pooling_handler import NormPoolingHandler from .option import ShardOption from .output_handler import OutputHandler +from .permute_handler import PermuteHandler from .placeholder_handler import PlaceholderHandler from .registry import operator_registry -from .reshape_handler import ReshapeHandler from .softmax_handler import SoftmaxHandler +from .split_handler import SplitHandler from .sum_handler import SumHandler from .tensor_constructor_handler import TensorConstructorHandler +from .transpose_handler import TransposeHandler from 
.unary_elementwise_handler import UnaryElementwiseHandler +from .view_handler import ViewHandler from .where_handler import WhereHandler __all__ = [ 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', - 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', + 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', - 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption' + 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption', + 'TransposeHandler', 'SplitHandler' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py similarity index 87% rename from colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py index 7763b1884..0c5b9f39e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py @@ -5,23 +5,23 @@ import torch from ..sharding_strategy import OperationData, OperationDataType from .node_handler import MetaInfoNodeHandler, NodeHandler from .registry import operator_registry -from .strategy import ReshapeGenerator, StrategyGenerator +from .strategy import DefaultReshapeGenerator, StrategyGenerator -__all__ = ['ReshapeHandler'] +__all__ = ['DefaultReshapeHandler'] @operator_registry.register(torch.flatten) @operator_registry.register(torch.Tensor.unsqueeze) @operator_registry.register(torch.nn.AdaptiveAvgPool2d) -class ReshapeHandler(MetaInfoNodeHandler): +class DefaultReshapeHandler(MetaInfoNodeHandler): """ - A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. + A DefaultReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. 
""" def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] - generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + generators.append(DefaultReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) return generators def infer_logical_shape(self, data): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py deleted file mode 100644 index 15f66104b..000000000 --- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .permute_handler import PermuteHandler -from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator -from .split_handler import SplitHandler -from .transpose_handler import TransposeHandler -from .view_handler import ViewHandler - -__all__ = [ - 'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator', - 'SplitHandler', 'SplitGenerator' -] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py deleted file mode 100644 index b7248d011..000000000 --- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py +++ /dev/null @@ -1,299 +0,0 @@ -import copy -from typing import List - -from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.auto_parallel.tensor_shard.utils import ( - check_keep_sharding_status, - detect_reshape_mapping, - infer_output_dim_partition_dict, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from colossalai.tensor.sharding_spec import ShardingSpec - -__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator'] - - -class ReshapeGenerator(FollowingStrategyGenerator): - """ - ReshapeGenerator is the base class for all the reshape operation. - """ - - def validate(self) -> bool: - return super().validate() - - def update_compute_cost(self, strategy: ShardingStrategy): - compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) - strategy.compute_cost = compute_cost - - def update_memory_cost(self, strategy: ShardingStrategy): - ''' - Compute the memory cost per device with this specific strategy. 
- ''' - forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") - } - - backward_size_mapping = copy.deepcopy(forward_size_mapping) - backward_size_mapping.pop("output") - # compute fwd cost incurred - # fwd_cost = input + output - fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) - fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) - fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) - - # compute bwd cost incurred - # bwd_cost = input_grad - bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) - bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) - bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) - - # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) - memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) - strategy.memory_cost = memory_cost - - def collate_strategies(self) -> List[ShardingStrategy]: - return super().collate_strategies() - - -class ViewGenerator(ReshapeGenerator): - """ - ViewGenerator deals with the sharding strategies of view op. - """ - - def collate_strategies(self) -> List[ShardingStrategy]: - strategy_list = [] - for index, strategy in enumerate(self.predecessor_node.strategies_vector): - dim_partition_dict_mapping = {} - communication_action_mapping = {} - input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - - origin_shape = self.op_data['input'].data.shape - tgt_shape = self.op_data['tgt_shape'].data - - reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) - - dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict - keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict) - - if keep_sharding_status: - dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input, - reshape_mapping_dict) - else: - dim_partition_dict_for_output = {} - - dim_partition_dict_mapping = { - "input": dim_partition_dict_for_input, - "output": dim_partition_dict_for_output, - } - sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - - # add index into name to pass the duplicated check - # we keep same strategies with different name for node merging, and it will not increase the searching space, - # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. - if keep_sharding_status: - name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - else: - name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}' - - # add comm action for converting input to fully replicated - total_mesh_dim_list = [] - for mesh_dim_list in dim_partition_dict_for_input.values(): - total_mesh_dim_list.extend(mesh_dim_list) - # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. - if len(total_mesh_dim_list) == 1: - total_mesh_dim_list = total_mesh_dim_list[0] - # the total mesh dim list only has one element, so the shard dim has only one element as well. 
- shard_dim = list(dim_partition_dict_for_input.keys())[0] - input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping["input"], - communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - logical_process_axis=total_mesh_dim_list, - comm_type=CommType.BEFORE, - arg_index=0) - # it will gather the input through gather_dim during forward phase. - input_comm_action.comm_spec.gather_dim = shard_dim - # it will split the input activation grad through shard_dim during backward phase. - input_comm_action.comm_spec.shard_dim = shard_dim - - elif len(total_mesh_dim_list) >= 2: - source_spec = sharding_spec_mapping["input"] - target_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=source_spec.entire_shape, - dim_partition_dict={}) - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} - input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) - - else: - input_comm_action = None - - if input_comm_action is not None: - communication_action_mapping["input"] = input_comm_action - - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - strategy_list.append(strategy) - - return strategy_list - - -class PermuteGenerator(ReshapeGenerator): - """ - PermuteGenerator deals with the sharding strategies of permute op. - """ - - def collate_strategies(self) -> List[ShardingStrategy]: - strategy_list = [] - for index, strategy in enumerate(self.predecessor_node.strategies_vector): - dim_partition_dict_mapping = {} - communication_action_mapping = {} - input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - - permute_dims = self.op_data['permute_dims'].data - dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict - dim_partition_dict_for_output = {} - for dim_index, permute_dim in enumerate(permute_dims): - if permute_dim in dim_partition_dict_for_input: - dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim] - - dim_partition_dict_mapping = { - "input": dim_partition_dict_for_input, - "output": dim_partition_dict_for_output, - } - sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - - # add index into name to pass the duplicated check - # we keep same strategies with different name for node merging, and it will not increase the searching space, - # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. - name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - strategy_list.append(strategy) - - return strategy_list - - -class TransposeGenerator(ReshapeGenerator): - """ - TransposeGenerator deals with the sharding strategies of permute op. 
- """ - - def collate_strategies(self) -> List[ShardingStrategy]: - strategy_list = [] - for index, strategy in enumerate(self.predecessor_node.strategies_vector): - dim_partition_dict_mapping = {} - communication_action_mapping = {} - input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict - dim_partition_dict_for_output = {} - - transpose_dims = self.op_data['transpose_dims'].data - dim_0 = transpose_dims[0] - dim_1 = transpose_dims[1] - for dim, sharded_dims in dim_partition_dict_for_input.items(): - if dim == dim_0: - dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0] - elif dim == dim_1: - dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1] - else: - dim_partition_dict_for_output[dim] = sharded_dims - - dim_partition_dict_mapping = { - "input": dim_partition_dict_for_input, - "output": dim_partition_dict_for_output, - } - sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - - # add index into name to pass the duplicated check - # we keep same strategies with different name for node merging, and it will not increase the searching space, - # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. - name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - strategy_list.append(strategy) - - return strategy_list - - -class SplitGenerator(ReshapeGenerator): - """ - SplitGenerator deals with the sharding strategies of split op. - """ - - def collate_strategies(self) -> List[ShardingStrategy]: - strategy_list = [] - for index, strategy in enumerate(self.predecessor_node.strategies_vector): - recover_dims = None - dim_partition_dict_mapping = {} - communication_action_mapping = {} - input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - split_size, split_dim = self.op_data['split_info'].data - - if split_dim in dim_partition_dict_for_input: - recover_dims = dim_partition_dict_for_input.pop(split_dim) - - dim_partition_dict_for_output = [ - copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data)) - ] - assert len(dim_partition_dict_for_output) >= 2 - dim_partition_dict_mapping = { - "input": dim_partition_dict_for_input, - "output": dim_partition_dict_for_output, - } - sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - # add index into name to pass the duplicated check - # we keep same strategies with different name for node merging, and it will not increase the searching space, - # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. - name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}' - - # add comm action if the input need to be recovered to replica in the split dimension. - if recover_dims: - # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. 
- if len(recover_dims) == 1: - recover_dims = recover_dims[0] - input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping["input"], - communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - logical_process_axis=recover_dims, - comm_type=CommType.BEFORE, - arg_index=0) - # it will gather the input through gather_dim during forward phase. - input_comm_action.comm_spec.gather_dim = split_dim - # it will split the input activation grad through split_dim during backward phase. - input_comm_action.comm_spec.shard_dim = split_dim - - elif len(recover_dims) >= 2: - # original sharding spec - source_spec = input_sharding_spec - # target sharding spec - target_spec = sharding_spec_mapping["input"] - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} - input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) - - else: - input_comm_action = None - - if input_comm_action is not None: - communication_action_mapping["input"] = input_comm_action - - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) - strategy_list.append(strategy) - - return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py similarity index 92% rename from colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py index 6d625e153..91e4a5105 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py @@ -2,11 +2,10 @@ from typing import Dict, List import torch -from ...sharding_strategy import OperationData, OperationDataType -from ..node_handler import NodeHandler -from ..registry import operator_registry -from ..strategy import StrategyGenerator -from .reshape_generator import PermuteGenerator +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import PermuteGenerator, StrategyGenerator __all__ = ['PermuteHandler'] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py similarity index 89% rename from colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py index 38c5eed7d..653d158b7 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py @@ -2,11 +2,10 @@ from typing import Dict, List import torch -from ...sharding_strategy import OperationData, OperationDataType -from ..node_handler import NodeHandler -from ..registry import operator_registry -from ..strategy import StrategyGenerator -from .reshape_generator import SplitGenerator +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import SplitGenerator, StrategyGenerator __all__ = ['SplitHandler'] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py 
b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py index 8d25475f9..db1f31521 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py @@ -14,7 +14,13 @@ from .matmul_strategy_generator import ( from .normal_pooling_generator import NormalPoolStrategyGenerator from .output_generator import OutputGenerator from .placeholder_generator import PlaceholderGenerator -from .reshape_generator import ReshapeGenerator +from .reshape_generator import ( + DefaultReshapeGenerator, + PermuteGenerator, + SplitGenerator, + TransposeGenerator, + ViewGenerator, +) from .softmax_generator import SoftmaxGenerator from .strategy_generator import StrategyGenerator from .sum_generator import SumGenerator @@ -26,7 +32,8 @@ __all__ = [ 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', - 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', - 'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', - 'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator' + 'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator', + 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator', + 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator', + 'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py index 0b3506c27..39983e918 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py @@ -1,6 +1,7 @@ import copy from typing import List +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommAction, CommType, @@ -8,17 +9,20 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ShardingStrategy, TrainCycleItem, ) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.sharding_spec import ShardingSpec -from .strategy_generator import FollowingStrategyGenerator - -__all__ = ['ReshapeGenerator'] +__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator'] class ReshapeGenerator(FollowingStrategyGenerator): """ - ReshapeGenerator which deals with the sharding strategies of Reshape Op, such as torch.Tensor.permute. + ReshapeGenerator is the base class for all the reshape operation. 
""" def validate(self) -> bool: @@ -57,11 +61,255 @@ class ReshapeGenerator(FollowingStrategyGenerator): memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost + def collate_strategies(self) -> List[ShardingStrategy]: + return super().collate_strategies() + + +class ViewGenerator(ReshapeGenerator): + """ + ViewGenerator deals with the sharding strategies of view op. + """ + def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] - # For reshape function, to keep the computing correctness we keep the sharding - # spec of input is fully replicated. In addition, we will keep the output in - # replica status and let the successor node choose the way to resharding the + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + origin_shape = self.op_data['input'].data.shape + tgt_shape = self.op_data['tgt_shape'].data + + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict) + + if keep_sharding_status: + dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input, + reshape_mapping_dict) + else: + dim_partition_dict_for_output = {} + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + if keep_sharding_status: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + else: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}' + + # add comm action for converting input to fully replicated + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition_dict_for_input.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + # the total mesh dim list only has one element, so the shard dim has only one element as well. + shard_dim = list(dim_partition_dict_for_input.keys())[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.BEFORE, + arg_index=0) + # it will gather the input through gather_dim during forward phase. + input_comm_action.comm_spec.gather_dim = shard_dim + # it will split the input activation grad through shard_dim during backward phase. 
+ input_comm_action.comm_spec.shard_dim = shard_dim + + elif len(total_mesh_dim_list) >= 2: + source_spec = sharding_spec_mapping["input"] + target_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=source_spec.entire_shape, + dim_partition_dict={}) + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class PermuteGenerator(ReshapeGenerator): + """ + PermuteGenerator deals with the sharding strategies of permute op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + permute_dims = self.op_data['permute_dims'].data + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + for dim_index, permute_dim in enumerate(permute_dims): + if permute_dim in dim_partition_dict_for_input: + dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim] + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class TransposeGenerator(ReshapeGenerator): + """ + TransposeGenerator deals with the sharding strategies of permute op. 
+ """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + dim_partition_dict_for_output = {} + + transpose_dims = self.op_data['transpose_dims'].data + dim_0 = transpose_dims[0] + dim_1 = transpose_dims[1] + for dim, sharded_dims in dim_partition_dict_for_input.items(): + if dim == dim_0: + dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0] + elif dim == dim_1: + dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1] + else: + dim_partition_dict_for_output[dim] = sharded_dims + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class SplitGenerator(ReshapeGenerator): + """ + SplitGenerator deals with the sharding strategies of split op. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + recover_dims = None + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) + split_size, split_dim = self.op_data['split_info'].data + + if split_dim in dim_partition_dict_for_input: + recover_dims = dim_partition_dict_for_input.pop(split_dim) + + dim_partition_dict_for_output = [ + copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data)) + ] + assert len(dim_partition_dict_for_output) >= 2 + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}' + + # add comm action if the input need to be recovered to replica in the split dimension. + if recover_dims: + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. 
+ if len(recover_dims) == 1: + recover_dims = recover_dims[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=recover_dims, + comm_type=CommType.BEFORE, + arg_index=0) + # it will gather the input through gather_dim during forward phase. + input_comm_action.comm_spec.gather_dim = split_dim + # it will split the input activation grad through split_dim during backward phase. + input_comm_action.comm_spec.shard_dim = split_dim + + elif len(recover_dims) >= 2: + # original sharding spec + source_spec = input_sharding_spec + # target sharding spec + target_spec = sharding_spec_mapping["input"] + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list + + +class DefaultReshapeGenerator(ReshapeGenerator): + """ + DefaultReshapeGenerator which deals with the sharding strategies of Reshape Op which have to recover the tensor + to Replica status. + """ + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + # For default reshape strategy, to keep the computing correctness we keep the + # sharding spec of input is fully replicated. In addition, we will keep the output + # in replica status and let the successor node choose the way to resharding the # output node. Therefore, the different strategies of input node with same # output sharding spec will generate same strategy for reshape function. 
for index, strategy in enumerate(self.predecessor_node.strategies_vector): @@ -114,9 +362,4 @@ class ReshapeGenerator(FollowingStrategyGenerator): communication_action_mapping=communication_action_mapping) strategy_list.append(strategy) - for strategy in strategy_list: - self.update_communication_cost(strategy) - self.update_compute_cost(strategy) - self.update_memory_cost(strategy) - return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py similarity index 90% rename from colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py index 3c7336a93..7a9d37726 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py @@ -2,11 +2,10 @@ from typing import Dict, List import torch -from ...sharding_strategy import OperationData, OperationDataType -from ..node_handler import NodeHandler -from ..registry import operator_registry -from ..strategy import StrategyGenerator -from .reshape_generator import TransposeGenerator +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, TransposeGenerator __all__ = ['TransposeHandler'] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py similarity index 88% rename from colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py rename to colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py index 6be634593..7dff89d1d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py @@ -2,11 +2,10 @@ from typing import Dict, List import torch -from ...sharding_strategy import OperationData, OperationDataType -from ..node_handler import NodeHandler -from ..registry import operator_registry -from ..strategy import StrategyGenerator -from .reshape_generator import ViewGenerator +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import NodeHandler +from .registry import operator_registry +from .strategy import StrategyGenerator, ViewGenerator __all__ = ['ViewHandler'] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py similarity index 91% rename from tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py index de277002b..ea7c2b729 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn +from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler -from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import 
ReshapeHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer @@ -51,9 +51,9 @@ def test_reshape_handler(): strategies_vector=conv_strategies_vector) conv_handler.register_strategy(compute_resharding_cost=False) setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) - reshape_handler = ReshapeHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=reshape_strategies_vector) + reshape_handler = DefaultReshapeHandler(node=reshape_node, + device_mesh=device_mesh, + strategies_vector=reshape_strategies_vector) reshape_handler.register_strategy(compute_resharding_cost=False) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index 3c35da61b..c72d2a6a8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -5,10 +5,10 @@ import torch import torch.multiprocessing as mp import torch.nn as nn +from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler -from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer @@ -153,7 +153,9 @@ def test_getitem_from_tuple_handler(): ) input_handler.register_strategy(compute_resharding_cost=False) setattr(input_node, 'strategies_vector', input_strategies_vector) - split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector) + split_handler = DefaultReshapeHandler(node=split_node, + device_mesh=device_mesh, + strategies_vector=split_strategies_vector) split_handler.register_strategy(compute_resharding_cost=False) setattr(split_node, 'strategies_vector', split_strategies_vector) getitem_handler = GetItemHandler(node=getitem_node, diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py index c695b8843..b12db1332 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -5,8 +5,8 @@ import torch import torch.multiprocessing as mp import torch.nn as nn +from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler -from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import 
LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py index 9e8e905c5..813651869 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -5,8 +5,8 @@ import torch import torch.multiprocessing as mp import torch.nn as nn +from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler -from colossalai.auto_parallel.tensor_shard.node_handler.experimental import SplitHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh @@ -156,8 +156,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(split_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in split_strategies_vector] - for name in strategy_name_list: - print(name) + if model_cls.__name__ == 'ConvSplitModel': if split_dim == 0: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index 08a702789..d07d2f76c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -5,8 +5,8 @@ import torch import torch.multiprocessing as mp import torch.nn as nn +from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler -from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh
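
Migration note: this patch moves the reshape-family handlers out of the `experimental` subpackage and renames `ReshapeHandler`/`ReshapeGenerator` to `DefaultReshapeHandler`/`DefaultReshapeGenerator`. The sketch below is not part of the commit; it only restates the import-path change implied by the updated `__init__.py` and test files, and the construction pattern shown in the comments mirrors the tests, with `reshape_node`, `device_mesh`, and `reshape_strategies_vector` being placeholder objects built by the test harness rather than a complete working setup.

    # Old import paths (before this patch):
    #   from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, ViewHandler
    #   from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler

    # New import paths (after this patch), as exported by node_handler/__init__.py:
    from colossalai.auto_parallel.tensor_shard.node_handler import (
        DefaultReshapeHandler,
        PermuteHandler,
        SplitHandler,
        TransposeHandler,
        ViewHandler,
    )

    # Handler usage is unchanged apart from the rename; e.g., following
    # tests/test_node_handler/test_default_reshape_handler.py:
    #
    #   handler = DefaultReshapeHandler(node=reshape_node,
    #                                   device_mesh=device_mesh,
    #                                   strategies_vector=reshape_strategies_vector)
    #   handler.register_strategy(compute_resharding_cost=False)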