diff --git a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler_v2.py
new file mode 100644
index 000000000..185327c94
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/batch_norm_handler_v2.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn.functional as F
+from .node_handler import ModuleHandler, NodeHandler
+from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
+from ..strategy import BatchNormStrategyGenerator, StrategyGenerator_V2
+from typing import List, Dict
+from .registry import operator_registry
+
+__all__ = ['BatchNormModuleHandler']
+
+
+@operator_registry.register(torch.nn.BatchNorm1d)
+@operator_registry.register(torch.nn.BatchNorm2d)
+@operator_registry.register(torch.nn.BatchNorm3d)
+class BatchNormModuleHandler(ModuleHandler):
+    """
+    A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd modules.
+    """
+
+    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(BatchNormStrategyGenerator(op_data_mapping, self.device_mesh))
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # use transposed shape for strategies
+        # the strategies will be transformed back to their original shape in self.post_process
+        physical_input_operand = OperationData(name=str(self.node.args[0]),
+                                               type=OperationDataType.ARG,
+                                               data=self.node.args[0]._meta_data)
+        physical_other_operand = OperationData(name="weight",
+                                               type=OperationDataType.PARAM,
+                                               data=self.named_parameters['weight'],
+                                               logical_shape=self.named_parameters['weight'].shape)
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
+
+        if self.named_parameters['bias'] is not None:
+            physical_bias_operand = OperationData(name="bias",
+                                                  type=OperationDataType.PARAM,
+                                                  data=self.named_parameters['bias'])
+            mapping['bias'] = physical_bias_operand
+        return mapping
diff --git a/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py b/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py
new file mode 100644
index 000000000..8e9a16c55
--- /dev/null
+++ b/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py
@@ -0,0 +1,291 @@
+import operator
+from functools import reduce
+from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from .strategy_generator import StrategyGenerator_V2
+from typing import List
+from .._utils import exception_handler
+import copy
+
+__all__ = ['BatchNormStrategyGenerator']
+
+
+class BatchNormStrategyGenerator(StrategyGenerator_V2):
+    """
+    A StrategyGenerator which deals with the sharding strategies of batch normalization.
+
+    To keep the math consistent, there are two ways to do BatchNorm if the input
+    is sharded along the batch dimension:
+    1. Gather the input partitions along the batch dimension, then run the normal BatchNorm.
+    2. Run SyncBatchNorm on each input partition separately; the SyncBN op keeps
+    the computation mathematically correct.
+    In this generator, both methods are considered.
+    """
+
+    @property
+    def has_bias(self):
+        return 'bias' in self.op_data
+
+    def validate(self) -> bool:
+        '''
+        In the sanity check, we need to make sure the input data has the correct number of dimensions.
+        For BatchNorm1d, the input should be 3-dimensional ([N, C, L]).
+        For BatchNorm2d, the input should be 4-dimensional ([N, C, H, W]).
+        For BatchNorm3d, the input should be 5-dimensional ([N, C, D, H, W]).
+        '''
+        input_op_data = self.op_data['input']
+        assert input_op_data.dim() in (3, 4,
+                                       5), 'The dimension of the input fed into a BatchNorm op should be in the range of [3, 5].'
+
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+        '''
+        Compute the computation cost per device with this specific strategy.
+
+        Note: compute_cost needs to be divided by TFLOPS; for now it just records the computation size.
+        '''
+        # TODO: a constant coefficient needs to be added.
+        # 1D: (L) * N * Cin
+        # 2D: (H * W) * N * Cin
+        # 3D: (H * W * D) * N * Cin
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+        if self.has_bias:
+            # bias add is an element-wise operation, so the cost is equal to the product of the output shape.
+            bias_compute_cost = reduce(operator.mul, sharded_output_shape)
+        input_product = reduce(operator.mul, sharded_input_shape, 1)
+        forward_compute_cost = input_product
+        backward_activation_compute_cost = input_product
+        backward_weight_compute_cost = input_product
+        backward_compute_cost = backward_weight_compute_cost + backward_activation_compute_cost
+        if self.has_bias:
+            forward_compute_cost += bias_compute_cost
+            backward_compute_cost += bias_compute_cost
+        total_compute_cost = forward_compute_cost + backward_compute_cost
+        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
+        return compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+        forward_size_mapping = {
+            'input': self._compute_size_in_bytes(strategy, "input"),
+            'other': self._compute_size_in_bytes(strategy, "other"),
+            'output': self._compute_size_in_bytes(strategy, "output")
+        }
+
+        if self.has_bias:
+            bias_size = self._compute_size_in_bytes(strategy, "bias")
+            forward_size_mapping['bias'] = bias_size
+
+        backward_size_mapping = copy.deepcopy(forward_size_mapping)
+        backward_size_mapping.pop("output")
+        # compute fwd cost incurred
+        # fwd_cost = input + other + bias + output
+        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
+        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
+
+        # compute bwd cost incurred
+        # bwd_cost = input_grad + other_grad + bias_grad
+        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
+        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
+
+        # compute total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
+                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
+
+    def split_input_channel(self, mesh_dim_0):
+        strategy_list = []
+        name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
+        dim_partition_dict_mapping = {
+            "input": {
+                1: [mesh_dim_0]
+            },
+            "other": {
+                0: [mesh_dim_0]
+            },
+            "output": {
+                1: [mesh_dim_0]
+            },
+        }
+        if self.has_bias:
+            dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        communication_action_mapping = {}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
+        name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
+        dim_partition_dict_mapping = {
+            "input": {
+                1: [mesh_dim_0, mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+            "output": {
+                1: [mesh_dim_0, mesh_dim_1]
+            },
+        }
+        if self.has_bias:
+            dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        communication_action_mapping = {}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def non_split(self):
+        name = 'RR = RR x R'
+        dim_partition_dict_mapping = {
+            "input": {},
+            "other": {},
+            "output": {},
+        }
+        if self.has_bias:
+            dim_partition_dict_mapping["bias"] = {}
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        communication_action_mapping = {}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_input_batch(self, mesh_dim_0):
+        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
+        dim_partition_dict_mapping = {
+            "input": {
+                0: [mesh_dim_0]
+            },
+            "other": {},
+            "output": {
+                0: [mesh_dim_0]
+            },
+        }
+        if self.has_bias:
+            dim_partition_dict_mapping["bias"] = {}
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # set communication action
+        # For the SyncBN case, we don't need to do communication for weight and bias.
+        # TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
+        # with a SyncBN operation instead of inserting a communication node.
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_0)
+
+        communication_action_mapping = {"output": output_comm_spec}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
+        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
+        dim_partition_dict_mapping = {
+            "input": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+            "other": {},
+            "output": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+        }
+        if self.has_bias:
+            dim_partition_dict_mapping["bias"] = {}
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # set communication action
+        # For the SyncBN case, we don't need to do communication for the gradients of weight and bias.
+        # TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
+        # with a SyncBN operation instead of inserting a communication node.
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+
+        communication_action_mapping = {"output": output_comm_spec}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
+        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
+        dim_partition_dict_mapping = {
+            "input": {
+                0: [mesh_dim_0],
+                1: [mesh_dim_1],
+            },
+            "other": {
+                0: [mesh_dim_1],
+            },
+            "output": {
+                0: [mesh_dim_0],
+                1: [mesh_dim_1],
+            },
+        }
+        if self.has_bias:
+            dim_partition_dict_mapping["bias"] = {
+                0: [mesh_dim_1],
+            }
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        # set communication action
+        # For the SyncBN case, we don't need to do communication for the gradients of weight and bias.
+        # TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
+        # with a SyncBN operation instead of inserting a communication node.
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=[mesh_dim_0])
+
+        communication_action_mapping = {"output": output_comm_spec}
+
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def generate(self):
+        '''
+        Generate all possible strategies for a BatchNorm node and record them in the strategies_vector.
+        '''
+
+        strategy_list = []
+        # RS = RS x S
+        strategy_list.append(self.split_input_channel(0))
+        strategy_list.append(self.split_input_channel(1))
+
+        # RR = RR x R
+        strategy_list.append(self.non_split())
+
+        # RS01 = RS01 x S01
+        strategy_list.append(self.split_input_channel_1d(0, 1))
+
+        # SR = SR x R WITH SYNC_BN
+        strategy_list.append(self.split_input_batch(0))
+        strategy_list.append(self.split_input_batch(1))
+
+        # SS = SS x S WITH SYNC_BN
+        strategy_list.append(self.split_input_both_dim(0, 1))
+        strategy_list.append(self.split_input_both_dim(1, 0))
+
+        # S01R = S01R x R WITH SYNC_BN
+        strategy_list.append(self.split_input_batch_1d(0, 1))
+
+        return strategy_list
diff --git a/tests/test_auto_parallel/test_node_handler/test_batch_norm_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_batch_norm_handler_v2.py
new file mode 100644
index 000000000..c5fb9326e
--- /dev/null
+++ b/tests/test_auto_parallel/test_node_handler/test_batch_norm_handler_v2.py
@@ -0,0 +1,88 @@
+from colossalai.fx.tracer.meta_patch.patched_module import linear
+import torch
+import torch.nn as nn
+from colossalai.fx import ColoTracer, ColoGraphModule
+from colossalai.auto_parallel.solver.op_handler.batch_norm_handler_v2 import BatchNormModuleHandler
+from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+
+
+def test_bn_module_handler():
+    model = nn.Sequential(nn.BatchNorm2d(16).to('meta'))
+    tracer = ColoTracer()
+    # graph():
+    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
+    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
+    #     return _0
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')})
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    bn_mod_node = list(graph.nodes)[1]
+    strategies_vector = StrategiesVector(bn_mod_node)
+
+    # build handler
+    handler = BatchNormModuleHandler(node=bn_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+
+    # check operation data mapping
+    mapping = handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.logical_shape is not None
+        assert op_data.data is not None
+
+    assert mapping['input'].name == "input_1"
+    assert mapping['input'].data.is_meta
+    assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64])
+
+    assert mapping['other'].name == "weight"
+    assert mapping['other'].data.is_meta
+    assert mapping['other'].data.shape == torch.Size([16])
+    assert mapping['other'].type == OperationDataType.PARAM
+    assert mapping['other'].logical_shape == torch.Size([16])
+
+    assert mapping['bias'].name == "bias"
+    assert mapping['bias'].data.is_meta
+    assert mapping['bias'].data.shape == torch.Size([16])
+    assert mapping['bias'].type == OperationDataType.PARAM
+    assert mapping['bias'].logical_shape == torch.Size([16])
+
+    assert mapping['output'].name == "_0"
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    strategies_vector = handler.register_strategy()
+    #[ 'S01R = S01R x R WITH SYNC_BN']
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # RS = RS x S
+    assert 'RS0 = RS0 x S0' in strategy_name_list
+    assert 'RS1 = RS1 x S1' in strategy_name_list
+
+    # RR = RR x R
+    assert 'RR = RR x R' in strategy_name_list
+
+    # RS01 = RS01 x S01
+    assert 'RS01 = RS01 x S01' in strategy_name_list
+
+    # SR = SR x R WITH SYNC_BN
+    assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
+    assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
+
+    # SS = SS x S WITH SYNC_BN
+    assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
+    assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
+
+    # S01R = S01R x R WITH SYNC_BN
+    assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
+
+
+if __name__ == '__main__':
+    test_bn_module_handler()
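
Note (not part of the patch): the per-device compute-cost rule in update_compute_cost — one pass over the local input shard for the forward, two passes (activation grad plus weight grad) for the backward, plus an element-wise bias term — can be checked by hand against the test case above. The sketch below is illustrative only; per_device_compute_cost is a hypothetical helper, and the (4, 8, 64, 64) shard shape assumes the 'RS0 = RS0 x S0' strategy on the 2 x 2 mesh used in the test, which splits the 16 channels across mesh axis 0.

# Hypothetical sketch, not part of this PR: re-derive the per-device compute cost
# that BatchNormStrategyGenerator.update_compute_cost would report for one strategy.
import operator
from functools import reduce


def per_device_compute_cost(sharded_input_shape, sharded_output_shape, has_bias=True):
    # forward: one normalization pass over the local input shard
    input_product = reduce(operator.mul, sharded_input_shape, 1)
    fwd = input_product
    # backward: activation grad + weight grad over the local shard
    bwd = 2 * input_product
    if has_bias:
        # bias add is element-wise over the output shard
        bias_cost = reduce(operator.mul, sharded_output_shape, 1)
        fwd += bias_cost
        bwd += bias_cost
    return fwd, bwd, fwd + bwd


# 'RS0 = RS0 x S0' on a (2, 2) mesh shards the channel dim (16 -> 8 per device),
# so each device holds a (4, 8, 64, 64) shard of the (4, 16, 64, 64) input.
print(per_device_compute_cost((4, 8, 64, 64), (4, 8, 64, 64)))
# -> (262144, 393216, 655360)

The same shard shapes drive the memory bookkeeping: the channel-splitting strategies also shard "other" and "bias" along dim 0, while the SyncBN batch-splitting strategies keep weight and bias replicated and only shard the activations, as reflected in the dim_partition_dict_mapping of each split method.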