From 0c703189b975b3295e42cdc8a4341c5cfae1d6b4 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Fri, 23 Sep 2022 12:00:25 +0800 Subject: [PATCH] [autoparallel] add layernorm handler (#1629) --- colossalai/auto_parallel/solver/_utils.py | 38 ++- colossalai/auto_parallel/solver/constants.py | 18 +- .../solver/op_handler/bcast_op_handler.py | 40 +-- .../solver/op_handler/layer_norm_handler.py | 233 ++++++++++++++++++ .../solver/strategies_constructor.py | 95 +++++-- .../test_layer_norm_handler.py | 70 ++++++ 6 files changed, 433 insertions(+), 61 deletions(-) create mode 100644 colossalai/auto_parallel/solver/op_handler/layer_norm_handler.py create mode 100644 tests/test_auto_parallel/test_layer_norm_handler.py diff --git a/colossalai/auto_parallel/solver/_utils.py b/colossalai/auto_parallel/solver/_utils.py index c9f85fb01..2c545b74c 100644 --- a/colossalai/auto_parallel/solver/_utils.py +++ b/colossalai/auto_parallel/solver/_utils.py @@ -94,7 +94,43 @@ def exception_handler(func): def wrapper(*args, **kwargs): try: func(*args, **kwargs) - except Exception as e: + except AssertionError as e: warnings.warn(f'{e}') return wrapper + + +def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size): + dim_partition_list = [] + # enumerate all the 2D sharding cases + for i in range(dim_size): + for j in range(i + 1, dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]} + dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + dim_partition_list.append(dim_partition_dict_1) + for i in range(dim_size): + dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]} + dim_partition_list.append(dim_partition_dict_flatten) + + return dim_partition_list + + +def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size): + dim_partition_list = [] + # enumerate all the 1D sharding cases + for i in range(dim_size): + dim_partition_dict_0 = {i: [mesh_dim_0]} + dim_partition_list.append(dim_partition_dict_0) + + return dim_partition_list + + +def generate_sharding_size(dim_partition_dict, device_mesh): + total_sharding_size = 1 + for mesh_dim_list in dim_partition_dict.values(): + mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list] + sharding_size = reduce(operator.mul, mesh_dim_sharding_size) + total_sharding_size *= sharding_size + + return total_sharding_size diff --git a/colossalai/auto_parallel/solver/constants.py b/colossalai/auto_parallel/solver/constants.py index 6addb2ebd..727c3ef35 100644 --- a/colossalai/auto_parallel/solver/constants.py +++ b/colossalai/auto_parallel/solver/constants.py @@ -3,7 +3,8 @@ import operator __all__ = [ 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', - 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP' + 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP', + 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP' ] ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] @@ -11,7 +12,18 @@ ELEMENTWISE_FUNC_OP = [ torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten ] -RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape] +ELEMENTWISE_METHOD_OP = [ + torch.Tensor.to, + torch.Tensor.type, +] 
+RESHAPE_FUNC_OP = [torch.flatten, torch.reshape] +RESHAPE_METHOD_OP = [ + torch.Tensor.view, + torch.Tensor.unsqueeze, + torch.Tensor.split, + torch.Tensor.permute, + torch.Tensor.transpose, +] BCAST_FUNC_OP = [ torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, operator.mul, operator.floordiv, operator.truediv, torch.matmul @@ -23,9 +35,11 @@ CONV_MODULE_OP = [ CONV_FUNC_OP = [ torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d ] +EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding] LINEAR_MODULE_OP = [torch.nn.Linear] LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm] BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm] +LAYERNORM_MODULE_OP = [torch.nn.LayerNorm] POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d] NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP diff --git a/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py b/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py index 0b20d02fe..1f1d681e0 100644 --- a/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py @@ -8,7 +8,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec from copy import deepcopy from typing import Dict, List -from colossalai.auto_parallel.solver._utils import exception_handler +from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding __all__ = ['BcastOpHandler'] @@ -110,45 +110,19 @@ class BcastOpHandler(OperatorHandler): return sharding_spec_list - def _enumerate_all_possible_2d_sharding(self, mesh_dim_0, mesh_dim_1, dim_size): - dim_partition_list = [] - # enumerate all the 2D sharding cases - for i in range(dim_size): - for j in range(i + 1, dim_size): - dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]} - dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]} - dim_partition_list.append(dim_partition_dict_0) - dim_partition_list.append(dim_partition_dict_1) - for i in range(dim_size): - dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]} - dim_partition_list.append(dim_partition_dict_flatten) - - # sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list) - return dim_partition_list - - def _enumerate_all_possible_1d_sharding(self, mesh_dim_0, dim_size): - dim_partition_list = [] - # enumerate all the 1D sharding cases - for i in range(dim_size): - dim_partition_dict_0 = {i: [mesh_dim_0]} - dim_partition_list.append(dim_partition_dict_0) - - # sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list) - return dim_partition_list - def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity. 
output_dim_partition_list = []
         dim_size = self.output_data.dim()
         # enumerate all the 2D sharding cases
-        sharding_list_2d = self._enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
+        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
         output_dim_partition_list.extend(sharding_list_2d)
 
         # enumerate all the 1D sharding cases
-        sharding_list_1d_on_dim_0 = self._enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
+        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
         output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
 
-        sharding_list_1d_on_dim_1 = self._enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
+        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
         output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
 
         # add empty dict for fully replicated case
@@ -545,15 +519,13 @@ class BcastOpHandler(OperatorHandler):
         dim_size = self.output_data.dim() - 2
 
         # Both device mesh axises are uesd on batch dimensions
-        dim_partition_dicts_2d = self._enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1],
-                                                                          dim_size)
+        dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
         for dim_partition_dict in dim_partition_dicts_2d:
             self._registry_no_split_strategies_for_matmul(dim_partition_dict)
 
         # Only one device mesh axis is uesd on batch dimensions
         for mesh_dim_index in [0, 1]:
-            dim_partition_dicts_1d = self._enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index],
-                                                                              dim_size)
+            dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
             for dim_partition_dict in dim_partition_dicts_1d:
                 self._registry_no_split_strategies_for_matmul(dim_partition_dict)
                 self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])
diff --git a/colossalai/auto_parallel/solver/op_handler/layer_norm_handler.py b/colossalai/auto_parallel/solver/op_handler/layer_norm_handler.py
new file mode 100644
index 000000000..9b41f37be
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/layer_norm_handler.py
@@ -0,0 +1,233 @@
+import operator
+from functools import reduce
+import torch
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler
+from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size
+
+__all__ = ['LayerNormHandler']
+
+
+class LayerNormHandler(OperatorHandler):
+    """
+    An OperatorHandler which deals with the sharding strategies of layer normalization.
+
+    Note: to keep math consistency, LayerNorm does not allow sharding on the normalized (hidden) dimensions.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.weight = self.module_named_parameters['weight']
+        self.bias = self.module_named_parameters['bias']
+        self.output_data = self.node._meta_data
+
+    def _generate_compute_cost(self, total_sharding_size):
+        '''
+        Compute the computation cost per device with this specific strategy.
+
+        Note: compute_cost needs to be divided by TFLOPS; for now it only reflects the computation size.
+
+        Argument:
+            total_sharding_size(int): The total number of devices this strategy
+                shards the computation over.
+
+        Return:
+            compute_cost(float): Computation cost per device with this specific strategy
+        '''
+        # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
+        # TODO: a constant coefficient needs to be added.
+
+        norm_kernel_size = self.weight.shape
+        # in the LayerNorm context, batch dimensions are all the dimensions that do not join the normalization.
+        input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)]
+        input_batch_product = reduce(operator.mul, input_batch_shape, 1)
+        norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1)
+        forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
+        backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
+        # Computing the gradient of one norm kernel element requires input_batch_product operations, so
+        # the total cost is input_batch_product * norm_kernel_product.
+        backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
+        backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
+        compute_cost = forward_compute_cost + backward_compute_cost
+        return compute_cost
+
+    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
+        '''
+        Compute the memory cost per device with this specific strategy.
+
+        Argument:
+            sharding_size_forward(int): The forward activation will be divided
+                into sharding_size_forward number of partitions.
+            sharding_size_backward_activation(int): The backward activation will
+                be divided into sharding_size_backward_activation number of partitions.
+            sharding_size_weight(int): The backward weight will be divided
+                into sharding_size_weight number of partitions.
+
+        Return:
+            memory_cost(Tuple[float]): Memory cost per device with this
+                specific strategy; the first element of this tuple is the forward
+                memory cost, and the second element is the backward
+                memory cost.
+            memory_cost_forward_activation(float): Memory cost of forward activation per
+                device with this specific strategy.
+            memory_cost_backward_activation(float): Memory cost of backward activation
+                per device with this specific strategy.
+            memory_cost_backward_weight(float): Memory cost of the weight gradient
+                per device with this specific strategy.
+        '''
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel_output = self.output_data.numel()
+        # this operation does not change the shape of the input
+        numel_input = numel_output
+        numel_weight = self.weight.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+        # forward memory_cost
+        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
+        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
+        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
+
+        # backward memory_cost
+        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
+        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
+        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
+
+        # memory_cost pair
+        memory_cost = (memory_cost_forward, memory_cost_backward)
+
+        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
+
+    def _generate_strategy_with_dim_partition(self, dim_partition):
+        dim_partition_dict_for_input = dim_partition
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+
+        dim_partition_dict_for_weight = {}
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = dim_partition
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+
+        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}'
+        # generate resharding cost for this strategy
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+
+        total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh)
+        # compute the computation cost of this strategy
+        compute_cost = self._generate_compute_cost(total_sharding_size)
+
+        # compute the memory cost of this strategy
+        sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh)
+        sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh)
+        sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh)
+        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
+                                                                                    sharding_size_backward_activation,
+                                                                                    sharding_size_weight)
+
+        total_mesh_dim_list = []
+        for mesh_dim_list in dim_partition.values():
+            total_mesh_dim_list.extend(mesh_dim_list)
+
+        # This strategy does not need an all_reduce operation for activations
+        communication_cost_forward_activation = 0
+        communication_cost_backward_activation = 0
+        if len(total_mesh_dim_list) == 1:
+            communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight,
+                                                                                  total_mesh_dim_list[0])
+        else:
+            assert len(total_mesh_dim_list) == 2, f'temporarily we only support a 2D device mesh.'
+            communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+                memory_cost_backward_weight, 0)
+        communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight
+
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+
+        self.strategies_vector.append(sharding_strategies)
+
+    @exception_handler
+    def split_input_batch_single_mesh_dim(self, mesh_dim_0):
+        batch_dimension_length = self.input_data.dim() - self.weight.dim()
+        dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
+        for dim_partition in dim_partition_list:
+            self._generate_strategy_with_dim_partition(dim_partition)
+
+    @exception_handler
+    def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
+        batch_dimension_length = self.input_data.dim() - self.weight.dim()
+        dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
+        for dim_partition in dim_partition_list:
+            self._generate_strategy_with_dim_partition(dim_partition)
+
+    @exception_handler
+    def non_split(self):
+        name = f'RR = RR x R'
+
+        dim_partition_dict_for_input = {}
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+
+        dim_partition_dict_for_weight = {}
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {}
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+
+        # generate resharding cost for this strategy
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+
+        total_sharding_size = 1
+        # compute the computation cost of this strategy
+        compute_cost = self._generate_compute_cost(total_sharding_size)
+
+        # compute the memory cost of this strategy
+        sharding_size_forward = 1
+        sharding_size_backward_activation = 1
+        sharding_size_weight = 1
+        memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                          sharding_size_weight)
+
+        # This strategy does not need an all_reduce operation
+        communication_cost = 0
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+
+        self.strategies_vector.append(sharding_strategies)
+
+    def register_strategy(self) -> StrategiesVector:
+        '''
+        Generate all possible strategies for a LayerNorm node, and record them in the strategies_vector.
+
+        Example:
+            norm_handler = LayerNormHandler(node, device_mesh, strategies_vector)
+            norm_handler.register_strategy()
+            for strategy in norm_handler.strategies_vector:
+                print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
+
+        Output:
+            [S0, R, R] = [S0, R, R] x [R], computation_cost: ..., memory_cost: ...
+            [R, S0, R] = [R, S0, R] x [R], computation_cost: ..., memory_cost: ...
+            ...
+            RR = RR x R, computation_cost: ..., memory_cost: ...
+        '''
+
+        # SR = SR x R with single mesh dim on batch dimensions
+        self.split_input_batch_single_mesh_dim(0)
+        self.split_input_batch_single_mesh_dim(1)
+
+        # SR = SR x R with both mesh dims on batch dimensions
+        self.split_input_batch_both_mesh_dim(0, 1)
+
+        # RR = RR x R
+        self.non_split()
+
+        return self.strategies_vector
diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py
index 6b037aef7..ab67d37e9 100644
--- a/colossalai/auto_parallel/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/solver/strategies_constructor.py
@@ -1,5 +1,6 @@
 from torch.fx import Graph, Node
 from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
+from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -216,6 +217,15 @@ class StrategiesConstructor:
                                                          input_shardings=[input_sharding_spec])
                     strategies_vector.append(sharding_strategy)
 
+                # embedding module
+                elif submod_type in EMBEDDING_MODULE_OP:
+                    embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
+                    embedding_handler.register_strategy()
+
+                # layernorm module
+                elif submod_type in LAYERNORM_MODULE_OP:
+                    layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
+                    layernorm_handler.register_strategy()
                 # other module
                 else:
                     raise RuntimeError(f'{submod_type} module is NOT supported now.')
@@ -349,35 +359,72 @@ class StrategiesConstructor:
                 elif target == operator.getitem:
                     index = node.args[1]
                     input_tensor_node = strategies_vector.predecessor_nodes[0]
-                    for strategy in input_tensor_node.strategies_vector:
-                        input_sharding_spec = input_tensor_node.output_sharding_spec[index]
-                        assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
-                        dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
-                        entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
-                        output_sharding_spec = ShardingSpec(self.device_mesh,
-                                                            entire_shape_output,
-                                                            dim_partition_dict=dim_partition_dict_for_output)
-                        # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
-                        compute_cost = 0
-                        memory_cost = 0
-                        resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                     [input_sharding_spec])
-                        # to prevent the resharding happening, set their resharding cost to inf.
- resharding_costs[input_tensor_node] = [ - cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node] - ] - sharding_strategy = ShardingStrategy(name, - output_sharding_spec, - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[input_tensor_node.output_sharding_spec]) - strategies_vector.append(sharding_strategy) + if isinstance(input_tensor_node, torch.Tensor): + for strategy in input_tensor_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec[index] + assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.' + dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict) + entire_shape_output = deepcopy(input_sharding_spec.entire_shape) + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partition_dict=dim_partition_dict_for_output) + # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. + compute_cost = 0 + memory_cost = 0 + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) + # to prevent the resharding happening, set their resharding cost to inf. + resharding_costs[input_tensor_node] = [ + cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node] + ] + sharding_strategy = ShardingStrategy( + name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_tensor_node.output_sharding_spec]) + strategies_vector.append(sharding_strategy) + # torch.arange function + elif target == torch.arange: + name = f'FULLY REPLICATED ARANGE' + entire_shape_output = node._meta_data.shape + dim_partition_dict_for_output = {} + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partition_dict=dim_partition_dict_for_output) + memory_cost = node._meta_data.numel() + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=0, + memory_cost=memory_cost) + strategies_vector.append(sharding_strategy) + + # op list to be processed to support gpt2 + elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax, + torch.nn.functional.softmax, torch.pow, torch.tanh): + pass # other function else: raise RuntimeError(f'{target} function is NOT supported now.') + # call_method node + if node.op == 'call_method': + method = getattr(node.args[0]._meta_data.__class__, node.target) + if method in (torch.Tensor.size, torch.Tensor.contiguous): + pass + elif method in ELEMENTWISE_METHOD_OP: + unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) + unary_elementwise_handler.register_strategy() + + elif method in RESHAPE_METHOD_OP: + reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector) + reshape_handler.register_strategy() + + else: + raise RuntimeError(f'{method} function is NOT supported now.') + # output node if node.op == 'output': if self.solver_options.fast: diff --git a/tests/test_auto_parallel/test_layer_norm_handler.py b/tests/test_auto_parallel/test_layer_norm_handler.py new file mode 100644 index 000000000..afab3934f --- /dev/null +++ b/tests/test_auto_parallel/test_layer_norm_handler.py @@ -0,0 +1,70 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest +from colossalai.auto_parallel.solver import 
sharding_strategy

+from colossalai.fx.proxy import ColoProxy
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class LNModel(nn.Module):
+
+    def __init__(self, c):
+        super().__init__()
+        self.ln = nn.LayerNorm(c)
+
+    def forward(self, x):
+        x = x * 2
+        x = self.ln(x)
+        return x
+
+
+def test_ln_handler():
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    entire_shape = torch.Size((4, 4, 128))
+
+    tracer = ColoTracer()
+    model = LNModel(128)
+    input_sample = {'x': torch.rand(4, 4, 128).to('meta')}
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
+    #     %ln : [#users=1] = call_module[target=ln](args = (%mul,), kwargs = {})
+    #     return ln
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    # [x, mul, ln, output]
+    nodes = [node for node in gm.graph.nodes]
+    sharding_spec_for_input = ShardingSpec(device_mesh, entire_shape, {})
+    sharding_strategy_for_input = ShardingStrategy('node_1', sharding_spec_for_input)
+    strategies_vector_for_input = StrategiesVector(nodes[1])
+    strategies_vector_for_input.append(sharding_strategy_for_input)
+    setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
+
+    # generate layernorm strategies
+    strategies_vector = StrategiesVector(node=nodes[2])
+    ln_handler = LayerNormHandler(
+        node=nodes[2],
+        device_mesh=device_mesh,
+        strategies_vector=strategies_vector,
+    )
+    ln_handler.register_strategy()
+    # ['[S0, R, R] = [S0, R, R] x [R]', '[R, S0, R] = [R, S0, R] x [R]', '[S1, R, R] = [S1, R, R] x [R]', '[R, S1, R] = [R, S1, R] x [R]',
+    #  '[S0, S1, R] = [S0, S1, R] x [R]', '[S1, S0, R] = [S1, S0, R] x [R]', '[S01, R, R] = [S01, R, R] x [R]', '[R, S01, R] = [R, S01, R] x [R]', 'RR = RR x R']
+    strategy_name_list = [strategy.name for strategy in ln_handler.strategies_vector]

+    assert len(strategy_name_list) == 9
+
+
+if __name__ == '__main__':
+    test_ln_handler()
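
For reference, a minimal standalone sketch of how the enumeration helpers added to colossalai/auto_parallel/solver/_utils.py yield the nine strategies asserted in test_ln_handler, assuming a 2x2 device mesh and a (4, 4, 128) input normalized over its last dimension, so two batch dimensions are shardable. The helper bodies are reproduced from the patch so the sketch runs without colossalai installed.

    # Sketch: count the layer-norm sharding strategies for two shardable batch
    # dims on a 2x2 device mesh (mesh dims 0 and 1). The helper bodies mirror
    # the ones this patch adds to _utils.py.

    def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
        dim_partition_list = []
        # shard two different tensor dims, one per mesh dim, in both orderings
        for i in range(dim_size):
            for j in range(i + 1, dim_size):
                dim_partition_list.append({i: [mesh_dim_0], j: [mesh_dim_1]})
                dim_partition_list.append({i: [mesh_dim_1], j: [mesh_dim_0]})
        # flatten both mesh dims onto one tensor dim (the S01-style shardings)
        for i in range(dim_size):
            dim_partition_list.append({i: [mesh_dim_0, mesh_dim_1]})
        return dim_partition_list

    def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
        # shard a single tensor dim on a single mesh dim
        return [{i: [mesh_dim_0]} for i in range(dim_size)]

    # LayerNorm(128) on a (4, 4, 128) input: batch_dimension_length = 3 - 1 = 2
    batch_dims = 2
    strategies = enumerate_all_possible_2d_sharding(0, 1, batch_dims)    # 4 dicts
    strategies += enumerate_all_possible_1d_sharding(0, batch_dims)      # 2 dicts
    strategies += enumerate_all_possible_1d_sharding(1, batch_dims)      # 2 dicts
    strategies.append({})                                                # RR = RR x R
    print(len(strategies))    # 9, matching assert len(strategy_name_list) == 9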