From 262652c8bcc3ad1ff198c4d4eff1c5e6c4a1ef62 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 21 Oct 2022 18:55:48 +0800 Subject: [PATCH] [autoparallel] added addbmm handler (#1751) --- .../tensor_shard/node_handler/__init__.py | 9 +- .../tensor_shard/node_handler/bmm_handler.py | 90 +++++++-- .../strategy/matmul_strategy_generator.py | 67 ++++++- .../strategy/strategy_generator.py | 3 - .../tensor_shard/utils/broadcast.py | 7 +- .../meta_patch/patched_function/arithmetic.py | 11 + .../test_node_handler/test_addbmm_handler.py | 189 ++++++++++++++++++ .../test_node_handler/test_bmm_handler.py | 12 +- 8 files changed, 353 insertions(+), 35 deletions(-) create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index d8dbaa0ac..b9227e2ec 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -1,5 +1,5 @@ from .batch_norm_handler import BatchNormModuleHandler -from .bmm_handler import BMMFunctionHandler +from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler from .conv_handler import ConvFunctionHandler, ConvModuleHandler from .layer_norm_handler import LayerNormModuleHandler from .linear_handler import LinearFunctionHandler, LinearModuleHandler @@ -12,7 +12,8 @@ from .unary_elementwise_handler import UnaryElementwiseHandler from .where_handler import WhereHandler __all__ = [ - 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'LayerNormModuleHandler', - 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', - 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry' + 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', + 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', + 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', + 'NormPoolingHandler', 'operator_registry' ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py index a1ca06a74..09016d507 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py @@ -1,33 +1,97 @@ -from typing import Dict, List +from typing import Dict, List, Union import torch -from ..sharding_strategy import OperationData, OperationDataType +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from ..utils import recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator +__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler'] + + +def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): + """ + This function is a helper function which extracts the common logic for both `bmm` and `addbmm` + node handler to reduce code redundancy. + """ + # input operand + physical_input_operand = OperationData(name=str(node.args[input_idx]), + type=OperationDataType.ARG, + data=node.args[input_idx]._meta_data) + + # other operand + physical_other_operand = OperationData(name=str(node.args[other_idx]), + type=OperationDataType.ARG, + data=node.args[other_idx]._meta_data) + + # output + physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data) + mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + + if bias_idx is not None: + # bias physical shape + bias_logical_shape = node._meta_data.shape + physical_bias_operand = OperationData(name=str(node.args[bias_idx]), + type=OperationDataType.ARG, + data=node.args[bias_idx]._meta_data, + logical_shape=bias_logical_shape) + mapping['bias'] = physical_bias_operand + return mapping + @operator_registry.register(torch.bmm) @operator_registry.register(torch.Tensor.bmm) class BMMFunctionHandler(NodeHandler): + """ + This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch. + Such operations including `torch.bmm` and `torch.Tensor.bmm` require the tensor to be 3D, thus, there is + no logical-physical shape conversion in this handler. + """ def get_operation_data_mapping(self) -> Dict[str, OperationData]: - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=OperationDataType.ARG, - data=self.node.args[1]._meta_data) - physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) - - mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} + mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=0, other_idx=1) return mapping def get_strategy_generator(self) -> List[StrategyGenerator]: - generators = [] op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) return generators + + +@operator_registry.register(torch.addbmm) +@operator_registry.register(torch.Tensor.addbmm) +class AddBMMFunctionHandler(NodeHandler): + """ + This is a NodeHandler class which deals with the addition + batched matrix multiplication operation in PyTorch. + Such operations including `torch.addbmm` and `torch.Tensor.addbmm` require the two matmul tensor to be 3D. However, due to the + addition, logical-physical shape conversion is required for the bias term. + + As the addbmm operation will reduce the batch dimension, the bias is maximum 2D. + """ + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=1, other_idx=2, bias_idx=0) + return mapping + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) + return generators + + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: + # convert bias from its logical sharding spec to its physical sharding spec + op_data_mapping = self.get_operation_data_mapping() + + if 'bias' in op_data_mapping: + bias_op_data = op_data_mapping['bias'] + bias_physical_shape = bias_op_data.data.shape + bias_logical_shape = bias_op_data.logical_shape + bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) + bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape, + bias_physical_shape) + strategy.sharding_specs[bias_op_data] = bias_sharding_spec + return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index d178ebe7a..be2a95098 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -514,23 +514,60 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] + + The bias term is considered to have a 2D logical shape. """ + def __init__(self, *args, **kwargs): + self.squeeze_batch_dim = False + super().__init__(*args, **kwargs) + + def _pop_batch_dim_sharding_for_output(self, dim_partition_dict): + # remove partition dict for dim 0 + dim_partition_dict['output'].pop(0, None) + + # decrease the remaining dim index by 1 + temp_dim_partition = {} + keys = list(dim_partition_dict['output'].keys()) + for key in keys: + val = dim_partition_dict['output'].pop(key) + temp_dim_partition[key - 1] = val + dim_partition_dict['output'].update(temp_dim_partition) + def validate(self) -> bool: input_op_data = self.op_data['input'] other_op_data = self.op_data['other'] - assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2 + assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3 + + if 'bias' in self.op_data: + bias_op_data = self.op_data['bias'] + assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2 + + if self.op_data['output'].data.dim() == 2: + # addbmm will shrink the first batch dim + self.squeeze_batch_dim = True def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape) + fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul, + self.op_data['output'].data.shape) + bwd_compute_cost = fwd_compute_cost * 2 + compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + bwd=bwd_compute_cost, + total=fwd_compute_cost + bwd_compute_cost) + strategy.compute_cost = compute_cost + @ignore_sharding_exception def split_one_batch_dim(self, mesh_dim): name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' # get sharding_spec dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}} + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + print(sharding_spec_mapping) + # get communication actions communication_action_mapping = {} if self.has_bias: @@ -543,6 +580,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}' dim_partition_dict = { @@ -557,6 +595,8 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): 0: [mesh_dim_0, mesh_dim_1] } } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication actions @@ -572,22 +612,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}' dim_partition_dict = { "input": { 0: [mesh_dim_0], - -2: [mesh_dim_1] + 1: [mesh_dim_1] }, "other": { 0: [mesh_dim_0] }, - "bias": {}, + "bias": { + 0: [mesh_dim_1] + }, "output": { 0: [mesh_dim_0], - -2: [mesh_dim_1] + 1: [mesh_dim_1] } } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication actions @@ -609,6 +654,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}' dim_partition_dict = { @@ -617,16 +663,18 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): }, "other": { 0: [mesh_dim_0], - -1: [mesh_dim_1] + 2: [mesh_dim_1] }, "bias": { - -1: [mesh_dim_1] + 1: [mesh_dim_1] }, "output": { 0: [mesh_dim_0], - -1: [mesh_dim_1] + 2: [mesh_dim_1] } } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication actions @@ -648,6 +696,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) + @ignore_sharding_exception def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}' dim_partition_dict = { @@ -664,6 +713,8 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): 0: [mesh_dim_0], } } + if self.squeeze_batch_dim: + self._pop_batch_dim_sharding_for_output(dim_partition_dict) sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication actions diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index 6bbb15e57..8f57ee6a0 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -4,7 +4,6 @@ from functools import reduce from typing import Any, Dict, List, Union import torch - from torch.fx import Node from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( @@ -15,11 +14,9 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ShardingStrategy, TrainCycleItem, ) - from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec -from torch.fx import Node class StrategyGenerator(ABC): diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py index 027edb5f7..a0edce9b9 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py +++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py @@ -1,6 +1,8 @@ -import torch from enum import Enum, auto from typing import List + +import torch + from colossalai.tensor.sharding_spec import ShardingSpec __all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape'] @@ -56,6 +58,9 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe logical_num_dims = len(logical_shape) physical_num_dims = len(physical_shape) + assert logical_num_dims >= physical_num_dims, \ + 'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!' + # track the dim and its broadcasting type logical_dim_broadcast_info = {} diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py index 00c434e03..3e697de86 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py @@ -1,4 +1,5 @@ import torch + from ..registry import meta_patched_function @@ -56,6 +57,16 @@ def torch_bmm(input, mat2, *, out=None): return torch.empty(batch_size, n, p, device="meta") +@meta_patched_function.register(torch.addbmm) +@meta_patched_function.register(torch.Tensor.addbmm) +def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + batch_size, n, m = mat1.shape + _, _, p = mat2.shape + return torch.empty(n, p, device="meta") + + @meta_patched_function.register(torch.var_mean) def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): assert out is None, 'saving to out is not supported yet' diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py new file mode 100644 index 000000000..54cd473b4 --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import parameterize + + +class AddBMMTensorMethodModule(nn.Module): + + def forward(self, bias, x1, x2): + return bias.addbmm(x1, x2) + + +class AddBMMTorchFunctionModule(nn.Module): + + def forward(self, bias, x1, x2): + return torch.addbmm(bias, x1, x2) + + +@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) +def test_2d_device_mesh(module, bias_shape): + + model = module() + tracer = ColoTracer() + graph = tracer.trace(model, + meta_args={ + 'bias': torch.rand(*bias_shape).to('meta'), + "x1": torch.rand(4, 8, 16).to('meta'), + 'x2': torch.rand(4, 16, 8).to('meta') + }) + print(graph) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + linear_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size(bias_shape) + assert mapping['bias'].type == OperationDataType.ARG + assert mapping['bias'].logical_shape == torch.Size([8, 8]) + + assert mapping['output'].name == "addbmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + + # one batch dim + assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + + # two batch dim + assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + + # SbSi = SbSi x Sb + assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list + assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + + # SbSj = SbR x SbSj + assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list + assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + + # SbR = SbSk x SbSk + assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list + assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) +def test_1d_device_mesh(module, bias_shape): + model = module() + tracer = ColoTracer() + graph = tracer.trace(model, + meta_args={ + 'bias': torch.rand(*bias_shape).to('meta'), + "x1": torch.rand(4, 8, 16).to('meta'), + 'x2': torch.rand(4, 16, 8).to('meta') + }) + print(graph) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (1, 4) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + linear_mod_node = list(graph.nodes)[3] + strategies_vector = StrategiesVector(linear_mod_node) + + # build handler + handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector) + + # check operation data mapping + mapping = handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.logical_shape is not None + assert op_data.data is not None + + assert mapping['input'].name == "x1" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([4, 8, 16]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + + assert mapping['other'].name == "x2" + assert mapping['other'].data.is_meta + assert mapping['other'].data.shape == torch.Size([4, 16, 8]) + assert mapping['other'].type == OperationDataType.ARG + assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + + assert mapping['bias'].name == "bias" + assert mapping['bias'].data.is_meta + assert mapping['bias'].data.shape == torch.Size(bias_shape) + assert mapping['bias'].type == OperationDataType.ARG + assert mapping['bias'].logical_shape == torch.Size([8, 8]) + + assert mapping['output'].name == "addbmm" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([8, 8]) + assert mapping['output'].type == OperationDataType.OUTPUT + + strategies_vector = handler.register_strategy(compute_resharding_cost=False) + strategy_name_list = [val.name for val in strategies_vector] + assert len(strategy_name_list) == 1 + # one batch dim + assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + + for strategy in strategies_vector: + input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0] + assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + + +if __name__ == '__main__': + test_1d_device_mesh() + # test_2d_device_mesh() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index 5d7272bec..f59fea90d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -6,6 +6,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandle from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import parameterize class BMMTensorMethodModule(nn.Module): @@ -20,7 +21,7 @@ class BMMTorchFunctionModule(nn.Module): return torch.bmm(x1, x2) -@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) +@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) def test_2d_device_mesh(module): model = module() @@ -95,12 +96,13 @@ def test_2d_device_mesh(module): output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') # make sure the sharding matches across different operation data + print(input_sharding_spec.sharding_sequence, output_sharding_spec.sharding_sequence) assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) +@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) def test_1d_device_mesh(module): model = module() tracer = ColoTracer() @@ -165,7 +167,5 @@ def test_1d_device_mesh(module): if __name__ == '__main__': - test_1d_device_mesh(BMMTensorMethodModule) - test_1d_device_mesh(BMMTorchFunctionModule) - test_2d_device_mesh(BMMTensorMethodModule) - test_2d_device_mesh(BMMTorchFunctionModule) + test_1d_device_mesh() + test_2d_device_mesh()