From 56088e6d9858ad58550e0e0cb4302ad5132e7def Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 13 Oct 2022 13:42:13 +0800
Subject: [PATCH] [autoparallel] add pooling handler (#1690)

* [autoparallel] add pooling handler

* polish code
---
 .../op_handler/normal_pooling_handler.py      |  40 ++++++
 .../auto_parallel/solver/strategy/__init__.py |   3 +-
 .../strategy/normal_pooling_generator.py      | 117 ++++++++++++++++++
 .../test_norm_pooling_handler.py              |  54 ++++++++
 4 files changed, 213 insertions(+), 1 deletion(-)
 create mode 100644 colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py
 create mode 100644 colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py
 create mode 100644 tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py

diff --git a/colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py b/colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py
new file mode 100644
index 000000000..59baa9631
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn.functional as F
+from .node_handler import ModuleHandler, NodeHandler
+from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
+from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator_V2
+from typing import List, Dict
+from .registry import operator_registry
+
+__all__ = ['NormPoolingHandler']
+
+
+@operator_registry.register(torch.nn.MaxPool1d)
+@operator_registry.register(torch.nn.MaxPool2d)
+@operator_registry.register(torch.nn.MaxPool3d)
+@operator_registry.register(torch.nn.AvgPool1d)
+@operator_registry.register(torch.nn.AvgPool2d)
+@operator_registry.register(torch.nn.AvgPool3d)
+class NormPoolingHandler(ModuleHandler):
+    """
+    A NormPoolingHandler which deals with the sharding strategies of nn.MaxPoolXd and nn.AvgPoolXd modules.
+    """
+
+    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh))
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # the kernel size is wrapped as the "other" operand so that the strategy
+        # generator can use it when estimating the computation cost
+        physical_input_operand = OperationData(name=str(self.node.args[0]),
+                                               type=OperationDataType.ARG,
+                                               data=self.node.args[0]._meta_data)
+        physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size)
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {"input": physical_input_operand, "other": physical_weight_operand, "output": physical_output}
+
+        return mapping
diff --git a/colossalai/auto_parallel/solver/strategy/__init__.py b/colossalai/auto_parallel/solver/strategy/__init__.py
index a7bffb2e8..e7ecbb58c 100644
--- a/colossalai/auto_parallel/solver/strategy/__init__.py
+++ b/colossalai/auto_parallel/solver/strategy/__init__.py
@@ -7,10 +7,11 @@ from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator
 from .layer_norm_generator import LayerNormGenerator
 from .where_generator import WhereGenerator
 from .reshape_generator import ReshapeGenerator
+from .normal_pooling_generator import NormalPoolStrategyGenerator
 
 __all__ = [
     'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
     'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
     'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
-    'TensorTupleStrategyGenerator', 'LayerNormGenerator', "WhereGenerator", 'ReshapeGenerator'
+    'TensorTupleStrategyGenerator', 'LayerNormGenerator', "WhereGenerator", 'ReshapeGenerator', 'NormalPoolStrategyGenerator'
 ]
diff --git a/colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py
new file mode 100644
index 000000000..a6d416797
--- /dev/null
+++ b/colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py
@@ -0,0 +1,117 @@
+import operator
+from functools import reduce
+from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from .strategy_generator import StrategyGenerator_V2
+from typing import List
+from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
+import copy
+
+
+class NormalPoolStrategyGenerator(StrategyGenerator_V2):
+    """
+    NormalPoolStrategyGenerator is a generic class to generate strategies for pooling operations such as MaxPoolXd.
+    We call these "normal" pooling ops because AvgPoolXd and MaxPoolXd both take kernel-sized windows from the
+    input feature map and reduce them, depending on the operation type.
+    """
+
+    def validate(self) -> bool:
+        '''
+        In the sanity check, we need to make sure the input data has the correct number of dimensions.
+        For Pool1d, the input should have 3 dimensions ([N, C, L]).
+        For Pool2d, the input should have 4 dimensions ([N, C, H, W]).
+        For Pool3d, the input should have 5 dimensions ([N, C, D, H, W]).
+        '''
+        input_op_data = self.op_data['input']
+        assert input_op_data.dim() in (3, 4, 5), \
+            f'We expect the input fed into the Pool op to have 3, 4 or 5 dimensions, got {input_op_data.dim()}.'
+
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+        '''
+        Compute the computation cost per device with this specific strategy.
+
+        Note: the compute cost still needs to be divided by TFLOPS; for now it only reflects the computation size.
+        '''
+        # TODO: divide the compute cost by TFLOPS; for now it only reflects the computation size.
+        # 1D: (Lout) * N * C * kernel
+        # 2D: (Hout * Wout) * N * C * kernel
+        # 3D: (Dout * Hout * Wout) * N * C * kernel
+        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+
+        kernel_size = self.op_data["other"].data
+        if isinstance(kernel_size, int):
+            kernel_size = [kernel_size] * (len(sharded_output_shape) - 2)
+        kernel_size_product = reduce(operator.mul, kernel_size)
+        output_size_product = reduce(operator.mul, sharded_output_shape)
+        input_size_product = reduce(operator.mul, sharded_input_shape)
+
+        forward_compute_cost = output_size_product * kernel_size_product
+        backward_compute_cost = input_size_product * kernel_size_product
+
+        total_compute_cost = forward_compute_cost + backward_compute_cost
+
+        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
+        return compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        forward_size_mapping = {
+            'input': self._compute_size_in_bytes(strategy, "input"),
+            'output': self._compute_size_in_bytes(strategy, "output")
+        }
+
+        backward_size_mapping = copy.deepcopy(forward_size_mapping)
+        backward_size_mapping.pop("output")
+        # compute fwd cost incurred
+        # fwd_cost = input + output
+        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
+
+        # compute bwd cost incurred
+        # bwd_cost = input_grad
+        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
+
+        # compute total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
+
+    def _generate_strategy_with_dim_partition(self, dim_partition):
+        dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+        communication_action_mapping = {}
+
+        strategy = self.get_sharding_strategy(name=name,
+                                              sharding_spec_mapping=sharding_spec_mapping,
+                                              communication_action_mapping=communication_action_mapping)
+
+        return strategy
+
+    def enumerate_all_possible_batch_dimensions_dim_partition(self, mesh_dim_0, mesh_dim_1):
+        dim_partition_list = []
+        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, 2))
+        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, 2))
+        dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, 2))
+        # append {} for non_split case
+        dim_partition_list.append({})
+
+        return dim_partition_list
+
+    def generate(self) -> List[ShardingStrategy_V2]:
+        strategy_list = []
+
+        dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
+        for dim_partition in dim_partition_list:
+            strategy = self._generate_strategy_with_dim_partition(dim_partition)
+            strategy_list.append(strategy)
+
+        for strategy in strategy_list:
+            self.update_communication_cost(strategy)
+            self.update_compute_cost(strategy)
+            self.update_memory_cost(strategy)
+
+        return strategy_list
diff --git a/tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py
new file mode 100644
index 000000000..3b03c7e91
--- /dev/null
+++ b/tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py
@@ -0,0 +1,54 @@
+from colossalai.fx.tracer.meta_patch.patched_module import linear
+import torch
+import torch.nn as nn
+from colossalai.fx import ColoTracer, ColoGraphModule
+from colossalai.auto_parallel.solver.op_handler.normal_pooling_handler import NormPoolingHandler
+from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+
+
+def test_norm_pool_handler():
+    model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
+    tracer = ColoTracer()
+    # graph():
+    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
+    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
+    #     return _0
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
+
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    pool_mod_node = list(graph.nodes)[1]
+    strategies_vector = StrategiesVector(pool_mod_node)
+
+    # build handler
+    handler = NormPoolingHandler(node=pool_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+    # check operation data mapping
+    mapping = handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.data is not None
+
+    assert mapping['input'].name == "input_1"
+    assert mapping['input'].data.is_meta
+    assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64])
+
+    assert mapping['output'].name == "_0"
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    strategies_vector = handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+    assert len(strategy_name_list) == 9
+
+
+if __name__ == '__main__':
+    test_norm_pool_handler()
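
For reference, the standalone sketch below (not part of the diff) replays the compute-cost arithmetic performed by NormalPoolStrategyGenerator.update_compute_cost for the MaxPool2d(4, padding=1) case exercised by the test above. The sharding layout it uses, the batch dimension split across all four devices of the 2x2 mesh, is an assumed example strategy rather than one asserted by the test.

# A minimal sketch of the per-device compute-cost estimate, assuming the batch
# dimension of the test's [4, 4, 64, 64] input is sharded across all 4 devices
# of the 2x2 mesh (an illustrative strategy, not pinned down by the patch).
import operator
from functools import reduce

kernel_size = 4                          # MaxPool2d(4, padding=1), as in the test
sharded_input_shape = [1, 4, 64, 64]     # [N, C, H, W] with N=4 split over 4 devices
sharded_output_shape = [1, 4, 16, 16]    # batch-sharded version of the asserted output shape

# an int kernel size is broadcast over the spatial dims, mirroring update_compute_cost
kernel = [kernel_size] * (len(sharded_output_shape) - 2)
kernel_size_product = reduce(operator.mul, kernel)                                        # 16
forward_compute_cost = reduce(operator.mul, sharded_output_shape) * kernel_size_product   # 16384
backward_compute_cost = reduce(operator.mul, sharded_input_shape) * kernel_size_product   # 262144

print(forward_compute_cost, backward_compute_cost, forward_compute_cost + backward_compute_cost)
# 16384 262144 278528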