From 095854477f1992fef7c5692c3392e5adaa904792 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 28 Sep 2022 11:24:59 +0800
Subject: [PATCH] [autoparallel] add conv handler v2 (#1663)

---
 .../solver/op_handler/conv_handler_v2.py      | 145 ++++++
 .../solver/op_handler/node_handler.py         |   2 -
 .../auto_parallel/solver/strategy/__init__.py |   3 +-
 .../strategy/conv_strategy_generator.py       | 491 ++++++++++++++++++
 .../test_node_handler/test_conv_handler_v2.py | 210 ++++++++
 .../test_linear_handler_v2.py                 |   2 +-
 6 files changed, 849 insertions(+), 4 deletions(-)
 create mode 100644 colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py
 create mode 100644 colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py
 create mode 100644 tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py

diff --git a/colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py
new file mode 100644
index 000000000..69a96c8ed
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py
@@ -0,0 +1,145 @@
+import torch
+import torch.nn.functional as F
+from .node_handler import ModuleHandler, NodeHandler
+from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
+from ..strategy import ConvStrategyGenerator, StrategyGenerator_V2
+from typing import List, Dict
+from .registry import operator_registry
+
+__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
+
+
+@operator_registry.register(torch.nn.Conv1d)
+@operator_registry.register(torch.nn.Conv2d)
+@operator_registry.register(torch.nn.Conv3d)
+class ConvModuleHandler(ModuleHandler):
+    """
+    A ConvModuleHandler which deals with the sharding strategies for nn.ConvXd modules.
+    """
+
+    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh))
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # use transposed shape for strategies
+        # the strategies will be transformed back to its original shape in self.post_process
+        physical_input_operand = OperationData(name=str(self.node.args[0]),
+                                               type=OperationDataType.ARG,
+                                               data=self.node.args[0]._meta_data)
+        logical_shape_for_weight = list(self.named_parameters["weight"].shape)
+        logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
+            1], logical_shape_for_weight[0]
+        physical_other_operand = OperationData(name="weight",
+                                               type=OperationDataType.PARAM,
+                                               data=self.named_parameters['weight'],
+                                               logical_shape=torch.Size(logical_shape_for_weight))
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
+
+        if self.named_parameters['bias'] is not None:
+            physical_bias_operand = OperationData(name="bias",
+                                                  type=OperationDataType.PARAM,
+                                                  data=self.named_parameters['bias'])
+            mapping['bias'] = physical_bias_operand
+        return mapping
+
+    def post_process(self, strategy: ShardingStrategy_V2):
+        """
+        Convert the sharding spec of the weight parameter back to its original shape.
+ """ + for op_data, sharding_spec in strategy.input_sharding_specs.items(): + if op_data.name == "weight": + assert op_data.logical_shape != op_data.data.shape + dim_partition_dict = sharding_spec.dim_partition_dict + + # switch first and second dim of the conv module weight + first_dim_partition = dim_partition_dict.pop(1, None) + second_dim_partition = dim_partition_dict.pop(0, None) + + if first_dim_partition: + dim_partition_dict[0] = first_dim_partition + + if second_dim_partition: + dim_partition_dict[1] = second_dim_partition + + # re-init the sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict) + return strategy + + +@operator_registry.register(F.conv1d) +@operator_registry.register(F.conv2d) +@operator_registry.register(F.conv3d) +class ConvFunctionHandler(NodeHandler): + """ + A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + + # check if the other operand is a parameter + if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + logical_shape_for_weight = list(self.node.args[1]._meta_data.shape) + logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ + 1], logical_shape_for_weight[0] + physical_other_operand = OperationData(name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=torch.Size(logical_shape_for_weight)) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if "bias" in self.node.kwargs: + # check if the other operand is a parameter + if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), + type=data_type, + data=self.node.kwargs["bias"]._meta_data) + mapping['bias'] = physical_bias_operand + return mapping + + def post_process(self, strategy: ShardingStrategy_V2): + """ + Convert the sharding spec of the weight parameter back to its original shape. 
+ """ + for op_data, sharding_spec in strategy.input_sharding_specs.items(): + if op_data.name == str(self.node.args[1]): + assert op_data.logical_shape != op_data.data.shape + dim_partition_dict = sharding_spec.dim_partition_dict + + # switch first and second dim of the conv function weight + first_dim_partition = dim_partition_dict.pop(1, None) + second_dim_partition = dim_partition_dict.pop(0, None) + + if first_dim_partition: + dim_partition_dict[0] = first_dim_partition + + if second_dim_partition: + dim_partition_dict[1] = second_dim_partition + + # re-init the sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict) + return strategy diff --git a/colossalai/auto_parallel/solver/op_handler/node_handler.py b/colossalai/auto_parallel/solver/op_handler/node_handler.py index f18f61b5d..c539c5a2b 100644 --- a/colossalai/auto_parallel/solver/op_handler/node_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/node_handler.py @@ -82,8 +82,6 @@ class ModuleHandler(NodeHandler): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - print("created") - # set attributes to access module parameters for convenience assert self.node.graph.owning_module is not None, \ f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.' diff --git a/colossalai/auto_parallel/solver/strategy/__init__.py b/colossalai/auto_parallel/solver/strategy/__init__.py index 634b3e5af..568499095 100644 --- a/colossalai/auto_parallel/solver/strategy/__init__.py +++ b/colossalai/auto_parallel/solver/strategy/__init__.py @@ -1,7 +1,8 @@ from .strategy_generator import StrategyGenerator_V2 from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator +from .conv_strategy_generator import ConvStrategyGenerator __all__ = [ 'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', - 'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator' + 'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator' ] diff --git a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py new file mode 100644 index 000000000..bddfc6b65 --- /dev/null +++ b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py @@ -0,0 +1,491 @@ +import operator +from functools import reduce +from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from .strategy_generator import StrategyGenerator_V2 +from typing import List +from .._utils import exception_handler +import copy + + +class ConvStrategyGenerator(StrategyGenerator_V2): + """ + ConvStrategyGenerator is a generic class to generate strategies. + The operation data is defined as `output = input x other + bias`. + """ + + @property + def has_bias(self): + return 'bias' in self.op_data + + def validate(self) -> bool: + ''' + In sanity check, we need make sure the input data having correct dimension size. + For Conv1d, the dim of input data should be 3([N, C, L]). + For Conv2d, the dim of input data should be 4([N, C, H, W]). + For Conv3d, the dim of input data should be 5([N, C, H, W, D]). 
+        '''
+        input_op_data = self.op_data['input']
+        assert input_op_data.dim() in (3, 4, 5), 'The dim of the input fed into a conv op should be in the range of [3, 5].'
+
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+        '''
+        Compute the computation cost per device with this specific strategy.
+
+        Note: compute_cost needs to be divided by TFLOPS; for now it only reflects the computation size.
+        '''
+        # TODO: compute_cost needs to be divided by TFLOPS; for now it only reflects the computation size.
+        # 1D: (L) * N * Cout * Cin * kernel
+        # 2D: (H * W) * N * Cout * Cin * kernel
+        # 3D: (H * W * D) * N * Cout * Cin * kernel
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
+        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+        if self.has_bias:
+            # bias add is an element-wise operation, so the cost equals the product of the output shape.
+            bias_compute_cost = reduce(operator.mul, sharded_output_shape)
+
+        output_size = sharded_output_shape[2:]
+        output_size_product = reduce(operator.mul, output_size)
+        input_size = sharded_input_shape[2:]
+        input_size_product = reduce(operator.mul, input_size, 1)
+        kernel_size = sharded_other_shape[2:]
+        kernel_size_product = reduce(operator.mul, kernel_size, 1)
+        batch_size = sharded_input_shape[0]
+        channel_in = sharded_input_shape[1]
+        channel_out = sharded_other_shape[1]
+
+        forward_compute_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product
+
+        backward_activation_cost = input_size_product * batch_size * channel_in * channel_out * kernel_size_product
+        backward_weight_cost = output_size_product * batch_size * channel_in * channel_out * kernel_size_product
+        backward_compute_cost = backward_weight_cost + backward_activation_cost
+        if self.has_bias:
+            forward_compute_cost += bias_compute_cost
+            backward_compute_cost += bias_compute_cost
+        total_compute_cost = forward_compute_cost + backward_compute_cost
+
+        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
+        return compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        forward_size_mapping = {
+            'input': self._compute_size_in_bytes(strategy, "input"),
+            'other': self._compute_size_in_bytes(strategy, "other"),
+            'output': self._compute_size_in_bytes(strategy, "output")
+        }
+
+        if self.has_bias:
+            bias_size = self._compute_size_in_bytes(strategy, "bias")
+            forward_size_mapping['bias'] = bias_size
+
+        backward_size_mapping = copy.deepcopy(forward_size_mapping)
+        backward_size_mapping.pop("output")
+        # compute fwd cost incurred
+        # fwd_cost = input + other + bias + output
+        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
+        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
+
+        # compute bwd cost incurred
+        # bwd_cost = input_grad + other_grad + bias_grad
+        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
+        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
+
+        # compute total cost
+        total_mem_cost =
MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_activation_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": { + 1: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + 1: [mesh_dim_1] + }, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_spec = self.get_communication_spec( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_1) + + communication_action_mapping = {"input": input_comm_spec} + + if self.is_param("other"): + other_comm_spec = self.get_communication_spec( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0) + communication_action_mapping["other"] = other_comm_spec + + if self.has_bias and self.is_param("bias"): + bias_comm_spec = self.get_communication_spec( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0) + communication_action_mapping["bias"] = bias_comm_spec + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def split_input_batch(self, mesh_dim_0): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0] + }, + "other": {}, + "output": { + 0: [mesh_dim_0], + }, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + if self.is_param("other"): + other_comm_spec = self.get_communication_spec( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0) + communication_action_mapping["other"] = other_comm_spec + + if self.has_bias and self.is_param("bias"): + bias_comm_spec = self.get_communication_spec( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0) + communication_action_mapping["bias"] = bias_comm_spec + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "other": { + 0: [mesh_dim_1] + }, + "output": { + 0: [mesh_dim_0], + }, + } + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_spec = self.get_communication_spec( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + 
logical_process_axis=mesh_dim_1) + + communication_action_mapping = {"output": output_comm_spec} + + if self.is_param("other"): + other_comm_spec = self.get_communication_spec( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0) + communication_action_mapping["other"] = other_comm_spec + + if self.has_bias and self.is_param("bias"): + bias_comm_spec = self.get_communication_spec( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0) + communication_action_mapping["bias"] = bias_comm_spec + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0], + }, + "other": { + 0: [mesh_dim_0], + 1: [mesh_dim_1], + }, + "output": { + 1: [mesh_dim_1], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_1], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_spec = self.get_communication_spec( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0) + input_comm_spec = self.get_communication_spec( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0) + + communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def split_input_in_channel_weight_in_channel(self, mesh_dim_0): + name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' + + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0], + }, + "other": { + 0: [mesh_dim_0], + }, + "output": {}, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_spec = self.get_communication_spec( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=mesh_dim_0) + + communication_action_mapping = {"output": output_comm_spec} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def split_weight_out_channel(self, mesh_dim_0): + name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' + + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0], + }, + "output": { + 1: [mesh_dim_0], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_0], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_spec = self.get_communication_spec( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim_0) + + communication_action_mapping = {"input": input_comm_spec} + + return 
self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def non_split(self): + name = f'RR = RR x RR' + + dim_partition_dict_mapping = { + "input": {}, + "other": {}, + "output": {}, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): + name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + + dim_partition_dict_mapping = { + "input": { + 0: [mesh_dim_0, mesh_dim_1], + }, + "other": {}, + "output": { + 0: [mesh_dim_0, mesh_dim_1], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + communication_action_mapping = {} + if self.is_param("other"): + other_comm_spec = self.get_communication_spec( + sharding_spec_mapping["other"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1]) + communication_action_mapping["other"] = other_comm_spec + + if self.has_bias and self.is_param("bias"): + bias_comm_spec = self.get_communication_spec( + sharding_spec_mapping["bias"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1]) + communication_action_mapping["bias"] = bias_comm_spec + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + dim_partition_dict_mapping = { + "input": { + 1: [mesh_dim_0, mesh_dim_1], + }, + "other": { + 0: [mesh_dim_0, mesh_dim_1], + }, + "output": {}, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = {} + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + output_comm_spec = self.get_communication_spec( + sharding_spec_mapping["output"], + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1]) + + communication_action_mapping = {"output": output_comm_spec} + + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): + name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + dim_partition_dict_mapping = { + "input": {}, + "other": { + 1: [mesh_dim_0, mesh_dim_1], + }, + "output": { + 1: [mesh_dim_0, mesh_dim_1], + }, + } + + if self.has_bias: + dim_partition_dict_mapping["bias"] = { + 0: [mesh_dim_0, mesh_dim_1], + } + + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # set communication action + input_comm_spec = self.get_communication_spec( + sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1]) + + communication_action_mapping = {"input": input_comm_spec} + + return self.get_sharding_strategy(name=name, + 
sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def generate(self) -> List[ShardingStrategy_V2]: + strategies = [] + # SS = SR x RS + strategies.append(self.split_input_batch_weight_out_channel(0, 1)) + strategies.append(self.split_input_batch_weight_out_channel(1, 0)) + + # SR = SR x RR + strategies.append(self.split_input_batch(0)) + strategies.append(self.split_input_batch(1)) + + # SR = SS x SR + strategies.append(self.split_input_both_dim_weight_in_channel(0, 1)) + strategies.append(self.split_input_both_dim_weight_in_channel(1, 0)) + + # RS = RS x SS + strategies.append(self.split_input_in_channel_weight_both_channel(0, 1)) + strategies.append(self.split_input_in_channel_weight_both_channel(1, 0)) + + # RR = RS x SR + strategies.append(self.split_input_in_channel_weight_in_channel(0)) + strategies.append(self.split_input_in_channel_weight_in_channel(1)) + + # RS = RR x RS + strategies.append(self.split_weight_out_channel(0)) + strategies.append(self.split_weight_out_channel(1)) + + # RR= RR x RR + strategies.append(self.non_split()) + + # S01R = S01R x RR + strategies.append(self.split_1d_parallel_on_input_batch(0, 1)) + + # RR = RS01 x S01R + strategies.append(self.split_1d_parallel_on_in_channel(0, 1)) + + # RS01 = RR x RS01 + strategies.append(self.split_1d_parallel_on_out_channel(0, 1)) + + # update mete info on cost + for strategy in strategies: + self.update_communication_cost(strategy) + self.update_compute_cost(strategy) + self.update_memory_cost(strategy) + + return strategies diff --git a/tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py new file mode 100644 index 000000000..c974fd34e --- /dev/null +++ b/tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py @@ -0,0 +1,210 @@ +from colossalai.fx.tracer.meta_patch.patched_module import linear +import torch +import torch.nn as nn +from colossalai.fx import ColoTracer, ColoGraphModule +from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvModuleHandler, ConvFunctionHandler +from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh + + +def test_conv_module_handler(): + model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta')) + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) + # return _0 + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')}) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[1] + strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + handler = ConvModuleHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['input'].type == 
OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['other'].name == "weight" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping['other'].type == OperationDataType.PARAM + assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.PARAM + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "_0" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy() + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1' in strategy_name_list + assert 'S1S0 = S1R x RS0' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0R x RR' in strategy_name_list + assert 'S1R = S1R x RR' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R' in strategy_name_list + assert 'S1R = S1S0 x S0R' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + +class ConvModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, others, bias=None): + x = nn.functional.conv2d(input, others, bias=bias, padding=1) + return x + + +def test_conv_function_handler(): + model = ConvModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %others : torch.Tensor [#users=1] = placeholder[target=others] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {}) + # return conv2d + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(4, 4, 64, 64).to('meta'), + "others": torch.rand(16, 4, 3, 3).to('meta'), + "bias": torch.rand(16).to('meta') + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + handler = ConvFunctionHandler(node=conv_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "input_1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert 
mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping['other'].name == "others" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size([16]) + assert mapping['bias'].type == OperationDataType.ARG + assert mapping['bias'].logical_shape == torch.Size([16]) + + assert mapping['output'].name == "conv2d" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy() + strategy_name_list = [val.name for val in strategies_vector] + + # SS = SR x RS + assert 'S0S1 = S0R x RS1' in strategy_name_list + assert 'S1S0 = S1R x RS0' in strategy_name_list + + # SR = SR x RR + assert 'S0R = S0R x RR' in strategy_name_list + assert 'S1R = S1R x RR' in strategy_name_list + + # SR = SS x SR + assert 'S0R = S0S1 x S1R' in strategy_name_list + assert 'S1R = S1S0 x S0R' in strategy_name_list + + # RS = RS x SS + assert 'RS0 = RS1 x S1S0' in strategy_name_list + assert 'RS1 = RS0 x S0S1' in strategy_name_list + + # RR = RS x SR + assert 'RR = RS0 x S0R' in strategy_name_list + assert 'RR = RS1 x S1R' in strategy_name_list + + # RS= RR x RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + + # RR = RR x RR + assert 'RR = RR x RR' in strategy_name_list + + # S01R = S01R x RR + assert 'S01R = S01R x RR' in strategy_name_list + + # RR = RS01 x S01R + assert 'RR = RS01 x S01R' in strategy_name_list + + # RS01 = RR x RS01 + assert 'RS01 = RR x RS01' in strategy_name_list + + +if __name__ == '__main__': + test_conv_module_handler() + test_conv_function_handler() diff --git a/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py index a69ff986d..993930060 100644 --- a/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py +++ b/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py @@ -48,7 +48,7 @@ def test_linear_module_handler(): assert mapping['bias'].data.is_meta assert mapping['bias'].data.shape == torch.Size([32]) assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping['bias'].logical_shape == torch.Size([32]) assert mapping['output'].name == "_0" assert mapping['output'].data.is_meta
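
The weight-transpose bookkeeping used by both handlers is the least obvious part of this change, so a minimal, self-contained sketch of that logic follows. It is plain Python for illustration only; the helper names logical_weight_shape and restore_weight_partition are not part of the patch. The conv weight of physical shape (out_channels, in_channels, *kernel) is exposed to ConvStrategyGenerator with its first two dims swapped, so the generator can shard dim 0 as the input channel and dim 1 as the output channel; post_process then swaps the first two keys of the resulting dim_partition_dict back to the physical layout.

# Illustrative sketch of the transpose performed in get_operation_data_mapping
# and undone in post_process (helper names are made up for this example).
import torch


def logical_weight_shape(physical_shape: torch.Size) -> torch.Size:
    # conv weight is stored as (out_channels, in_channels, *kernel);
    # the strategy generator works on (in_channels, out_channels, *kernel).
    shape = list(physical_shape)
    shape[0], shape[1] = shape[1], shape[0]
    return torch.Size(shape)


def restore_weight_partition(dim_partition_dict: dict) -> dict:
    # mirrors post_process: shards placed on logical dims 0/1 are moved
    # back to physical dims 1/0 of the conv weight.
    restored = dict(dim_partition_dict)
    first_dim_partition = restored.pop(1, None)
    second_dim_partition = restored.pop(0, None)
    if first_dim_partition:
        restored[0] = first_dim_partition
    if second_dim_partition:
        restored[1] = second_dim_partition
    return restored


weight = torch.empty(16, 4, 3, 3)    # (Cout, Cin, kH, kW), matching the tests above
assert logical_weight_shape(weight.shape) == torch.Size([4, 16, 3, 3])
# 'RS1 = RS0 x S0S1' shards logical dim 0 on mesh axis 0 and logical dim 1 on
# mesh axis 1; on the physical weight this becomes {0: [1], 1: [0]}.
assert restore_weight_partition({0: [0], 1: [1]}) == {0: [1], 1: [0]}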