[autoparallel] fixed wrong sharding strategy in conv handler (#1747)

* [autoparallel] fixed wrong sharding strategy in conv handler

* polish code
Author: Frank Lee
Committed by: GitHub
Date: 2022-10-20 16:12:39 +08:00
Parent: 8b8937d901
Commit: 474111ecb5

6 changed files with 75 additions and 60 deletions
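This commit tightens the conv handler tests in two ways: both tests are parameterized over `bias` in `[True, False]`, and every generated strategy is now checked for cross-operand consistency. For `conv2d`, the input is (N, C_in, H, W), the weight is (C_out, C_in, kH, kW), and the output is (N, C_out, H, W), so a valid strategy must shard the batch and spatial dims of input and output identically, the in-channel dim of input and weight identically, and the out-channel dim of weight, output, and (if present) bias identically. Below is a standalone sketch of that invariant using plain lists of dim tags instead of ColossalAI's `ShardingSpec`; the function name and the list encoding are illustrative only.

# Sketch of the consistency rule encoded by the new assertions (illustrative,
# not ColossalAI API). A sharding sequence tags each tensor dim with 'R'
# (replicated) or 'S0'/'S1'/'S01' (sharded along device-mesh dim(s)).

def is_consistent(input_seq, weight_seq, output_seq, bias_seq=None):
    """True if one conv strategy shards all operands coherently.

    input (N, C_in, H, W), weight (C_out, C_in, kH, kW),
    output (N, C_out, H, W), bias (C_out,)
    """
    ok = (output_seq[1] == weight_seq[0]        # out-channels: output vs weight
          and input_seq[0] == output_seq[0]     # batch: input vs output
          and input_seq[2:] == output_seq[2:]   # spatial dims: input vs output
          and input_seq[1] == weight_seq[1])    # in-channels: input vs weight
    if bias_seq is not None:
        # bias is 1-D over C_out, so its only dim must follow out-channels
        ok = ok and bias_seq[-1] == weight_seq[0] == output_seq[1]
    return ok

# Example: batch dim sharded along mesh dim 0, out-channels along mesh dim 1.
assert is_consistent(input_seq=['S0', 'R', 'R', 'R'],
                     weight_seq=['S1', 'R', 'R', 'R'],
                     output_seq=['S0', 'S1', 'R', 'R'],
                     bias_seq=['S1'])

# An inconsistent pairing (the bug class this commit guards against) fails:
assert not is_consistent(input_seq=['R', 'R', 'R', 'R'],
                         weight_seq=['S1', 'R', 'R', 'R'],
                         output_seq=['S0', 'S1', 'R', 'R'],  # batch dim disagrees with input
                         bias_seq=['S1'])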


@@ -5,12 +5,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import parameterize
 
 
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_module_handler():
-    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
+@parameterize('bias', [True, False])
+def test_conv_module_handler(bias):
+    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
     tracer = ColoTracer()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -49,11 +49,12 @@ def test_conv_module_handler():
     assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
-    assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['bias'].logical_shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
+        assert mapping['bias'].type == OperationDataType.PARAM
+        assert mapping['bias'].logical_shape == torch.Size([16])
 
     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
@@ -99,6 +100,24 @@ def test_conv_module_handler():
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list
 
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
+
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
 
 
 class ConvModel(nn.Module):
@@ -110,8 +129,8 @@ class ConvModel(nn.Module):
         return x
 
 
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_function_handler():
+@parameterize('bias', [True, False])
+def test_conv_function_handler(bias):
     model = ConvModel()
     tracer = ColoTracer()
     # graph():
@@ -119,18 +138,20 @@ def test_conv_function_handler():
     #     %others : torch.Tensor [#users=1] = placeholder[target=others]
     #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {})
     #     return conv2d
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(4, 4, 64, 64).to('meta'),
-                             "others": torch.rand(16, 4, 3, 3).to('meta'),
-                             "bias": torch.rand(16).to('meta')
-                         })
+    meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')}
+    if bias:
+        meta_args['bias'] = torch.rand(16).to('meta')
+
+    graph = tracer.trace(model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph)
 
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    conv_mod_node = list(graph.nodes)[3]
+    if bias:
+        conv_mod_node = list(graph.nodes)[3]
+    else:
+        conv_mod_node = list(graph.nodes)[2]
     strategies_vector = StrategiesVector(conv_mod_node)
     # build handler
@@ -157,11 +178,12 @@ def test_conv_function_handler():
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
-    assert mapping['bias'].type == OperationDataType.ARG
-    assert mapping['bias'].logical_shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
+        assert mapping['bias'].type == OperationDataType.ARG
+        assert mapping['bias'].logical_shape == torch.Size([16])
 
     assert mapping['output'].name == "conv2d"
     assert mapping['output'].data.is_meta
@@ -207,6 +229,24 @@ def test_conv_function_handler():
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list
 
+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d')
+
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
 
 
 if __name__ == '__main__':
     test_conv_module_handler()
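One note on the `__main__` block: `parameterize` from `colossalai.testing` wraps the test so that a plain call like `test_conv_module_handler()` executes the body once per listed value, which is how both the biased and bias-free variants run outside pytest. A rough, illustrative stand-in for that behavior (the real implementation lives in `colossalai.testing`):

import functools

def parameterize(arg_name, values):
    # Illustrative stand-in: run the wrapped function once per value.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

@parameterize('bias', [True, False])
def demo(bias):
    print(f'bias={bias}')

demo()    # runs twice: bias=True, then bias=False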