[autoparallel] implemented all matmul strategy generator (#1650)

commit 30e50c8b4a
parent 03978aad45
Author: Frank Lee
Date: 2022-09-27 12:06:25 +08:00
Committed by: GitHub

8 changed files with 440 additions and 76 deletions

View File

@@ -8,9 +8,9 @@ from colossalai.device.device_mesh import DeviceMesh
 def test_linear_module_handler():
-    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
+    model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
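Note on the mesh these tests shard against: the four physical devices above are folded into a 2x2 logical mesh, giving two mesh dimensions (0 and 1) for the S0/S1 shardings in the strategy names below. A minimal sketch of that setup, assuming a mesh_shape of (2, 2) as used by similar tests (the excerpt does not show it):

import torch
from colossalai.device.device_mesh import DeviceMesh

# Four physical devices 0..3, folded into a 2x2 logical mesh; mesh dim 0
# and mesh dim 1 are the two axes that the S0/S1 shardings refer to.
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)  # assumption: not visible in this excerpt
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)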
@@ -34,32 +34,55 @@ def test_linear_module_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list

 def test_linear_function_handler():
-    model = nn.Linear(10, 20).to('meta')
+    model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
@@ -77,27 +100,50 @@ def test_linear_function_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list

 if __name__ == '__main__':
     test_linear_module_handler()
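How to read the strategy names asserted above: each one spells output = input x weight, where Sk means that tensor dimension is sharded along logical mesh dimension k and R means it is replicated. Below is a minimal plain-PyTorch sketch of two of the asserted strategies, using the shapes from this test (single process, with chunk index 0 standing in for one rank; no ColossalAI APIs involved):

import torch

# Logical computation in the test: Y = X @ W, with X: [4, 16] and the
# weight viewed in its transposed logical shape W: [16, 32], so Y: [4, 32].
X = torch.rand(4, 16)
W = torch.rand(16, 32)

# 'S0S1 = S0R x RS1': shard X's rows along mesh dim 0 and W's columns along
# mesh dim 1; each rank multiplies its local tiles with no communication,
# and the output comes out sharded along both mesh dimensions.
X_local = X.chunk(2, dim=0)[0]   # S0R  -> [2, 16] on each rank
W_local = W.chunk(2, dim=1)[0]   # RS1  -> [16, 16] on each rank
Y_local = X_local @ W_local      # S0S1 -> a [2, 16] tile of the [4, 32] output
assert Y_local.shape == (2, 16)

# 'RR = RS0 x S0R': both operands are sharded along the contracted (inner)
# dimension, so each rank holds a partial sum of the full [4, 32] output;
# summing the partials (an all-reduce over mesh dim 0 in the distributed
# setting) reproduces the replicated result.
partials = [X.chunk(2, dim=1)[i] @ W.chunk(2, dim=0)[i] for i in range(2)]
Y_full = sum(partials)
assert torch.allclose(Y_full, X @ W, atol=1e-5)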

View File

@@ -1,6 +1,3 @@
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops

View File

@@ -1,6 +1,3 @@
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops