mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[autoparallel] fixed wrong sharding strategy in conv handler (#1747)
* [autoparallel] fixed wrong sharding strategy in conv handler * polish code
This commit is contained in:
@@ -5,12 +5,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_conv_module_handler():
|
||||
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
|
||||
@parameterize('bias', [True, False])
|
||||
def test_conv_module_handler(bias):
|
||||
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
@@ -49,11 +49,12 @@ def test_conv_module_handler():
|
||||
assert mapping['other'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
if bias:
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
|
||||
assert mapping['output'].name == "_0"
|
||||
assert mapping['output'].data.is_meta
|
||||
@@ -99,6 +100,24 @@ def test_conv_module_handler():
|
||||
# RS01 = RR x RS01
|
||||
assert 'RS01 = RR x RS01' in strategy_name_list
|
||||
|
||||
for strategy in strategies_vector:
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
||||
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
|
||||
|
||||
if bias:
|
||||
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
|
||||
|
||||
# make sure the sharding matches across different operation data
|
||||
assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
|
||||
assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
|
||||
|
||||
if bias:
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
||||
@@ -110,8 +129,8 @@ class ConvModel(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
def test_conv_function_handler():
|
||||
@parameterize('bias', [True, False])
|
||||
def test_conv_function_handler(bias):
|
||||
model = ConvModel()
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
@@ -119,18 +138,20 @@ def test_conv_function_handler():
|
||||
# %others : torch.Tensor [#users=1] = placeholder[target=others]
|
||||
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {})
|
||||
# return conv2d
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(4, 4, 64, 64).to('meta'),
|
||||
"others": torch.rand(16, 4, 3, 3).to('meta'),
|
||||
"bias": torch.rand(16).to('meta')
|
||||
})
|
||||
meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')}
|
||||
if bias:
|
||||
meta_args['bias'] = torch.rand(16).to('meta')
|
||||
graph = tracer.trace(model, meta_args=meta_args)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
conv_mod_node = list(graph.nodes)[3]
|
||||
|
||||
if bias:
|
||||
conv_mod_node = list(graph.nodes)[3]
|
||||
else:
|
||||
conv_mod_node = list(graph.nodes)[2]
|
||||
strategies_vector = StrategiesVector(conv_mod_node)
|
||||
|
||||
# build handler
|
||||
@@ -157,11 +178,12 @@ def test_conv_function_handler():
|
||||
assert mapping['other'].type == OperationDataType.ARG
|
||||
assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.ARG
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
if bias:
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([16])
|
||||
assert mapping['bias'].type == OperationDataType.ARG
|
||||
assert mapping['bias'].logical_shape == torch.Size([16])
|
||||
|
||||
assert mapping['output'].name == "conv2d"
|
||||
assert mapping['output'].data.is_meta
|
||||
@@ -207,6 +229,24 @@ def test_conv_function_handler():
|
||||
# RS01 = RR x RS01
|
||||
assert 'RS01 = RR x RS01' in strategy_name_list
|
||||
|
||||
for strategy in strategies_vector:
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
||||
weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d')
|
||||
|
||||
if bias:
|
||||
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
|
||||
|
||||
# make sure the sharding matches across different operation data
|
||||
assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
|
||||
assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
|
||||
assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
|
||||
|
||||
if bias:
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_conv_module_handler()
|
||||
|
Reference in New Issue
Block a user