From 474111ecb549abd4072d9df5190b8c73d0130eed Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Thu, 20 Oct 2022 16:12:39 +0800
Subject: [PATCH] [autoparallel] fixed wrong sharding strategy in conv handler (#1747)

* [autoparallel] fixed wrong sharding strategy in conv handler

* polish code
---
 .../tensor_shard/node_handler/conv_handler.py | 34 +-------
 .../node_handler/linear_handler.py            |  4 +-
 .../tensor_shard/utils/__init__.py            |  4 +-
 .../auto_parallel/tensor_shard/utils/misc.py  |  3 +-
 .../tensor_shard/utils/sharding.py            |  4 +-
 .../test_node_handler/test_conv_handler.py    | 86 ++++++++++++++-----
 6 files changed, 75 insertions(+), 60 deletions(-)

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
index 8463cc62b..0c00160ef 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
 
 from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..utils import transpose_partition_dim
 from .node_handler import ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -55,20 +56,7 @@ class ConvModuleHandler(ModuleHandler):
         """
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == "weight":
-                dim_partition_dict = sharding_spec.dim_partition_dict
-
-                # switch first and second dim of the conv module weight
-                first_dim_partition = dim_partition_dict.pop(1, None)
-                second_dim_partition = dim_partition_dict.pop(0, None)
-
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-
-                if second_dim_partition:
-                    dim_partition_dict[1] = second_dim_partition
-
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
+                transpose_partition_dim(sharding_spec, 0, 1)
         return strategy
 
 
@@ -110,7 +98,7 @@ class ConvFunctionHandler(NodeHandler):
 
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
 
-        if "bias" in self.node.kwargs:
+        if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
             # check if the other operand is a parameter
             if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
                 data_type = OperationDataType.PARAM
@@ -128,19 +116,5 @@ class ConvFunctionHandler(NodeHandler):
         """
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == str(self.node.args[1]):
-                assert op_data.logical_shape != op_data.data.shape
-                dim_partition_dict = sharding_spec.dim_partition_dict
-
-                # switch first and second dim of the conv function weight
-                first_dim_partition = dim_partition_dict.pop(1, None)
-                second_dim_partition = dim_partition_dict.pop(0, None)
-
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-
-                if second_dim_partition:
-                    dim_partition_dict[1] = second_dim_partition
-
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                transpose_partition_dim(sharding_spec, 0, 1)
         return strategy
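
Note on the conv_handler.py hunks above: both post_process() methods used to swap the partition of the weight's first two dimensions by hand and then re-initialise the sharding spec; they now delegate to the shared transpose_partition_dim() utility (whose rename from tranpose_partition_dim appears later in this patch). ConvFunctionHandler additionally registers a 'bias' operand only when the bias kwarg is present and not None. A minimal sketch of the dimension swap being centralised; swap_partition_dims below is a hypothetical stand-in for illustration, not the library implementation:

    # Hypothetical helper mirroring the removed hand-written logic: swap the mesh
    # dimensions assigned to two tensor dimensions in a dim_partition_dict.
    def swap_partition_dims(dim_partition_dict, dim1, dim2):
        d1 = dim_partition_dict.pop(dim1, None)
        d2 = dim_partition_dict.pop(dim2, None)
        if d1 is not None:
            dim_partition_dict[dim2] = d1
        if d2 is not None:
            dim_partition_dict[dim1] = d2
        return dim_partition_dict

    # A conv weight partitioned on mesh dim 0 along tensor dim 0 and mesh dim 1 along
    # tensor dim 1 has those assignments exchanged after the swap.
    print(swap_partition_dims({0: [0], 1: [1]}, 0, 1))    # {1: [0], 0: [1]}
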
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
index 4a8af4ca7..62210ebe9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Union
 import torch
 import torch.nn.functional as F
 
-from colossalai.auto_parallel.tensor_shard.utils import tranpose_partition_dim, update_partition_dim
+from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
 
@@ -30,7 +30,7 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    tranpose_partition_dim(sharding_spec, 0, -1)
+    transpose_partition_dim(sharding_spec, 0, -1)
     return strategy
 
 
diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
index 9032fc58f..380464bcd 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
@@ -5,13 +5,13 @@ from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
     generate_sharding_size,
-    tranpose_partition_dim,
+    transpose_partition_dim,
     update_partition_dim,
 )
 
 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
-    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py
index c0ef6df88..967847390 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/misc.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py
@@ -68,4 +68,5 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
             f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
 
     # make sure the entire shape matches the physical tensor shape
-    assert sharding_spec.entire_shape == tensor.shape
+    assert sharding_spec.entire_shape == tensor.shape, \
+        f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py
index 622a33367..e2ce59e0b 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/sharding.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py
@@ -8,12 +8,12 @@ import torch
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 __all__ = [
-    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]
 
 
-def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
+def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
     """
     Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
 
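
Note on the utils changes above: the rename from tranpose_partition_dim to transpose_partition_dim is applied consistently to the import block, to __all__ in both utils/__init__.py and utils/sharding.py, and to the function definition itself. One thing worth flagging in the untouched context lines of utils/__init__.py: there is no comma after 'check_sharding_spec_validity', so Python concatenates that literal with the string on the next line and __all__ ends up exporting a single fused name. A minimal illustration of that pitfall (standalone example, not the module's actual export list):

    # Adjacent string literals concatenate when the comma between them is missing.
    exports = [
        'check_sharding_spec_validity'
        'transpose_partition_dim',
    ]
    print(exports)    # ['check_sharding_spec_validitytranspose_partition_dim']
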
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
index 69fd411e0..97025729c 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
@@ -5,12 +5,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import parameterize
 
 
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_module_handler():
-    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
+@parameterize('bias', [True, False])
+def test_conv_module_handler(bias):
+    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
     tracer = ColoTracer()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -49,11 +49,12 @@
     assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
-    assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['bias'].logical_shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
+        assert mapping['bias'].type == OperationDataType.PARAM
+        assert mapping['bias'].logical_shape == torch.Size([16])
 
     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
@@ -99,6 +100,24 @@
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list
 
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
+
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
+
 
 class ConvModel(nn.Module):
 
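
Note on the consistency loop added to test_conv_module_handler above: it encodes the conv sharding contract for input (N, C_in, H, W), weight (C_out, C_in, kH, kW), bias (C_out,) and output (N, C_out, H, W). The output channel dimension must follow the weight's dim 0, the batch and spatial dimensions must follow the input, the input channel must match the weight's dim 1, and the bias must follow the output channel. A small standalone illustration with made-up sharding sequences (not taken from the test; 'S0'/'S1' mean sharded along a mesh dimension, 'R' means replicated):

    # Hypothetical sharding sequences for one strategy; they satisfy every assertion
    # that the test adds for the module handler.
    input_seq = ['S1', 'R', 'R', 'R']      # (N, C_in, H, W)
    weight_seq = ['S0', 'R', 'R', 'R']     # (C_out, C_in, kH, kW)
    bias_seq = ['S0']                      # (C_out,)
    output_seq = ['S1', 'S0', 'R', 'R']    # (N, C_out, H, W)

    assert output_seq[1] == weight_seq[0]      # output channel follows weight dim 0
    assert input_seq[0] == output_seq[0]       # batch dimension follows the input
    assert input_seq[2:] == output_seq[2:]     # spatial dimensions follow the input
    assert input_seq[1] == weight_seq[1]       # input channel matches weight dim 1
    assert bias_seq[-1] == weight_seq[0]       # bias follows the weight's output channel
    assert bias_seq[-1] == output_seq[1]       # and therefore the output channel too
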
@@ -110,8 +129,8 @@ class ConvModel(nn.Module):
         return x
 
 
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_function_handler():
+@parameterize('bias', [True, False])
+def test_conv_function_handler(bias):
     model = ConvModel()
     tracer = ColoTracer()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
     #     %others : torch.Tensor [#users=1] = placeholder[target=others]
     #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {})
     #     return conv2d
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(4, 4, 64, 64).to('meta'),
-                             "others": torch.rand(16, 4, 3, 3).to('meta'),
-                             "bias": torch.rand(16).to('meta')
-                         })
+    meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')}
+    if bias:
+        meta_args['bias'] = torch.rand(16).to('meta')
+    graph = tracer.trace(model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    conv_mod_node = list(graph.nodes)[3]
+
+    if bias:
+        conv_mod_node = list(graph.nodes)[3]
+    else:
+        conv_mod_node = list(graph.nodes)[2]
     strategies_vector = StrategiesVector(conv_mod_node)
 
     # build handler
@@ -157,11 +178,12 @@
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
-    assert mapping['bias'].type == OperationDataType.ARG
-    assert mapping['bias'].logical_shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
+        assert mapping['bias'].type == OperationDataType.ARG
+        assert mapping['bias'].logical_shape == torch.Size([16])
 
     assert mapping['output'].name == "conv2d"
     assert mapping['output'].data.is_meta
@@ -207,6 +229,24 @@
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list
 
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d')
+
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
+
 
 if __name__ == '__main__':
     test_conv_module_handler()
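
Note on test_conv_function_handler above: the index of the conv2d node depends on whether a bias meta argument is traced, because the tracer emits one placeholder per meta argument before the call_function node, which is why the test picks index 3 with bias and index 2 without. Since @parameterize('bias', [True, False]) from colossalai.testing runs the wrapped test once per listed value, the call in the __main__ block exercises both the bias and the no-bias paths. A small standalone illustration of the expected node ordering (hypothetical name lists, not actual ColoTracer output):

    # With a bias placeholder the graph is [input_1, others, bias, conv2d, output];
    # without it, [input_1, others, conv2d, output]. Hence indices 3 and 2 in the test.
    nodes_with_bias = ['input_1', 'others', 'bias', 'conv2d', 'output']
    nodes_without_bias = ['input_1', 'others', 'conv2d', 'output']

    assert nodes_with_bias.index('conv2d') == 3
    assert nodes_without_bias.index('conv2d') == 2
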