[autoparallel] fixed wrong sharding strategy in conv handler (#1747)

* [autoparallel] fixed wrong sharding strategy in conv handler

* polish code
Frank Lee 2022-10-20 16:12:39 +08:00 committed by GitHub
parent 8b8937d901
commit 474111ecb5
6 changed files with 75 additions and 60 deletions

View File

@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
 
 from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..utils import transpose_partition_dim
 from .node_handler import ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -55,20 +56,7 @@ class ConvModuleHandler(ModuleHandler):
         """
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == "weight":
-                dim_partition_dict = sharding_spec.dim_partition_dict
-                # switch first and second dim of the conv module weight
-                first_dim_partition = dim_partition_dict.pop(1, None)
-                second_dim_partition = dim_partition_dict.pop(0, None)
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-                if second_dim_partition:
-                    dim_partition_dict[1] = second_dim_partition
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
+                transpose_partition_dim(sharding_spec, 0, 1)
         return strategy
@@ -110,7 +98,7 @@ class ConvFunctionHandler(NodeHandler):
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
-        if "bias" in self.node.kwargs:
+        if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
             # check if the other operand is a parameter
             if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
                 data_type = OperationDataType.PARAM
@@ -128,19 +116,5 @@ class ConvFunctionHandler(NodeHandler):
         """
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == str(self.node.args[1]):
-                assert op_data.logical_shape != op_data.data.shape
-                dim_partition_dict = sharding_spec.dim_partition_dict
-                # switch first and second dim of the conv function weight
-                first_dim_partition = dim_partition_dict.pop(1, None)
-                second_dim_partition = dim_partition_dict.pop(0, None)
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-                if second_dim_partition:
-                    dim_partition_dict[1] = second_dim_partition
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                transpose_partition_dim(sharding_spec, 0, 1)
         return strategy
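
Why the transposed-weight handling above is needed: the conv handlers expose the weight through a transposed logical shape (in_channels first), while the physical nn.Conv2d weight stores out_channels first, so a partition recorded on logical dim 0 belongs on physical dim 1 and vice versa. The inline dict-popping code removed here performed exactly that swap; it is now delegated to the shared transpose_partition_dim utility. Below is a minimal plain-PyTorch sketch of the shape relationship, illustrative only and not part of the commit; the logical shape value is the one asserted in the tests further down.

import torch
import torch.nn as nn

# Physical layout of a Conv2d(4, 16, 3) weight: (out_channels, in_channels, kH, kW).
conv = nn.Conv2d(4, 16, 3, padding=1)
assert conv.weight.shape == torch.Size([16, 4, 3, 3])

# The handlers' logical view transposes dims 0 and 1 (the tests assert
# logical_shape == torch.Size([4, 16, 3, 3])), so a shard placed on logical dim 0
# must land on physical dim 1 and vice versa, hence transpose_partition_dim(spec, 0, 1).
logical_shape = torch.Size([4, 16, 3, 3])
swapped = torch.Size([conv.weight.shape[1], conv.weight.shape[0], *conv.weight.shape[2:]])
assert swapped == logical_shape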

View File

@@ -3,7 +3,7 @@ from typing import Dict, List, Union
 import torch
 import torch.nn.functional as F
 
-from colossalai.auto_parallel.tensor_shard.utils import tranpose_partition_dim, update_partition_dim
+from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
@@ -30,7 +30,7 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    tranpose_partition_dim(sharding_spec, 0, -1)
+    transpose_partition_dim(sharding_spec, 0, -1)
     return strategy

View File

@@ -5,13 +5,13 @@ from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
     generate_sharding_size,
-    tranpose_partition_dim,
+    transpose_partition_dim,
     update_partition_dim,
 )
 
 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
-    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]

View File

@@ -68,4 +68,5 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
         f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
 
     # make sure the entire shape matches the physical tensor shape
-    assert sharding_spec.entire_shape == tensor.shape
+    assert sharding_spec.entire_shape == tensor.shape, \
+        f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'

View File

@@ -8,12 +8,12 @@ import torch
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 __all__ = [
-    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]
 
 
-def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
+def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
     """
     Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
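
For orientation, a hedged usage sketch of the renamed helper follows. It assumes ShardingSpec is constructed positionally as (device_mesh, entire_shape, dim_partition_dict), matching the __init__ call removed from the conv handler above, and that dim_partition_dict maps a tensor dim to the list of mesh axes it is sharded over; the snippet is an illustration, not code from the repository.

import torch

from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# Shard dim 0 of a (16, 4, 3, 3) conv weight over mesh axis 0, then swap the
# partitions of dims 0 and 1 in place, as the conv handlers now do.
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))
spec = ShardingSpec(device_mesh, torch.Size([16, 4, 3, 3]), {0: [0]})

transpose_partition_dim(spec, 0, 1)
assert spec.dim_partition_dict == {1: [0]}    # the shard moved from dim 0 to dim 1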

View File

@@ -5,12 +5,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import parameterize
 
 
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_module_handler():
-    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
+@parameterize('bias', [True, False])
+def test_conv_module_handler(bias):
+    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
     tracer = ColoTracer()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -49,6 +49,7 @@ def test_conv_module_handler():
     assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
@@ -99,6 +100,24 @@ def test_conv_module_handler():
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list
 
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
+
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
+
 
 class ConvModel(nn.Module):
@@ -110,8 +129,8 @@ class ConvModel(nn.Module):
         return x
 
 
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_function_handler():
+@parameterize('bias', [True, False])
+def test_conv_function_handler(bias):
     model = ConvModel()
     tracer = ColoTracer()
     # graph():
@@ -119,18 +138,20 @@ def test_conv_function_handler():
     #     %others : torch.Tensor [#users=1] = placeholder[target=others]
    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {})
     #     return conv2d
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(4, 4, 64, 64).to('meta'),
-                             "others": torch.rand(16, 4, 3, 3).to('meta'),
-                             "bias": torch.rand(16).to('meta')
-                         })
+    meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')}
+    if bias:
+        meta_args['bias'] = torch.rand(16).to('meta')
+    graph = tracer.trace(model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph)
 
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    conv_mod_node = list(graph.nodes)[3]
+    if bias:
+        conv_mod_node = list(graph.nodes)[3]
+    else:
+        conv_mod_node = list(graph.nodes)[2]
     strategies_vector = StrategiesVector(conv_mod_node)
 
     # build handler
@@ -157,6 +178,7 @@ def test_conv_function_handler():
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
@@ -207,6 +229,24 @@ def test_conv_function_handler():
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list
 
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d')
+
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
+
 
 if __name__ == '__main__':
     test_conv_module_handler()
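
The sharding-sequence checks added to both tests encode the shape correspondences of conv2d: output dim 1 comes from weight dim 0 (out_channels), the batch dim follows the input, input dim 1 pairs with weight dim 1 (in_channels), and the bias lines up with the output channels. A small plain-PyTorch sanity check of those correspondences, illustrative only and with no sharding involved:

import torch
import torch.nn.functional as F

x = torch.rand(4, 4, 64, 64)        # (N, C_in, H, W), same sizes as the test's meta args
w = torch.rand(16, 4, 3, 3)         # (C_out, C_in, kH, kW)
b = torch.rand(16)                  # (C_out,)
y = F.conv2d(x, w, b, padding=1)

assert y.shape[1] == w.shape[0]     # output dim 1 <-> weight dim 0
assert y.shape[0] == x.shape[0]     # batch dim follows the input
assert x.shape[1] == w.shape[1]     # input channels <-> weight dim 1
assert b.shape[-1] == w.shape[0]    # bias <-> output channels
assert b.shape[-1] == y.shape[1]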