mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[autoparallel] implemented all matmul strategy generator (#1650)
This commit is contained in:
@@ -8,9 +8,9 @@ from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
|
||||
def test_linear_module_handler():
|
||||
model = nn.Sequential(nn.Linear(10, 20).to('meta'))
|
||||
model = nn.Sequential(nn.Linear(16, 32).to('meta'))
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
@@ -34,32 +34,55 @@ def test_linear_module_handler():
|
||||
|
||||
assert mapping['input'].name == "input_1"
|
||||
assert mapping['input'].data.is_meta
|
||||
assert mapping['input'].data.shape == torch.Size([4, 10])
|
||||
assert mapping['input'].data.shape == torch.Size([4, 16])
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([4, 10])
|
||||
assert mapping['input'].logical_shape == torch.Size([4, 16])
|
||||
|
||||
assert mapping['other'].name == "weight"
|
||||
assert mapping['other'].data.is_meta
|
||||
assert mapping['other'].data.shape == torch.Size([20, 10])
|
||||
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||
assert mapping['other'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
||||
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([20])
|
||||
assert mapping['bias'].data.shape == torch.Size([32])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
||||
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||
|
||||
assert mapping['output'].name == "_0"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([4, 20])
|
||||
assert mapping['output'].data.shape == torch.Size([4, 32])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
strategies_vector = handler.register_strategy()
|
||||
strategy_name_list = [val.name for val in strategies_vector]
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||
|
||||
# RS = RS x SS
|
||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||
assert 'RS1 = RS0 x S0S1' in strategy_name_list
|
||||
|
||||
# RR = RS x SR
|
||||
assert 'RR = RS0 x S0R' in strategy_name_list
|
||||
assert 'RR = RS1 x S1R' in strategy_name_list
|
||||
|
||||
# RS= RR x RS
|
||||
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||
|
||||
|
||||
def test_linear_function_handler():
|
||||
model = nn.Linear(10, 20).to('meta')
|
||||
model = nn.Linear(16, 32).to('meta')
|
||||
tracer = ColoTracer()
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
|
||||
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
@@ -77,27 +100,50 @@ def test_linear_function_handler():
|
||||
|
||||
assert mapping['input'].name == "input_1"
|
||||
assert mapping['input'].data.is_meta
|
||||
assert mapping['input'].data.shape == torch.Size([4, 10])
|
||||
assert mapping['input'].data.shape == torch.Size([4, 16])
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([4, 10])
|
||||
assert mapping['input'].logical_shape == torch.Size([4, 16])
|
||||
|
||||
assert mapping['other'].name == "weight"
|
||||
assert mapping['other'].data.is_meta
|
||||
assert mapping['other'].data.shape == torch.Size([20, 10])
|
||||
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||
assert mapping['other'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
||||
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([20])
|
||||
assert mapping['bias'].data.shape == torch.Size([32])
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
||||
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||
|
||||
assert mapping['output'].name == "linear"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([4, 20])
|
||||
assert mapping['output'].data.shape == torch.Size([4, 32])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
strategies_vector = handler.register_strategy()
|
||||
strategy_name_list = [val.name for val in strategies_vector]
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||
|
||||
# RS = RS x SS
|
||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||
assert 'RS1 = RS0 x S0S1' in strategy_name_list
|
||||
|
||||
# RR = RS x SR
|
||||
assert 'RR = RS0 x S0R' in strategy_name_list
|
||||
assert 'RR = RS1 x S1R' in strategy_name_list
|
||||
|
||||
# RS= RR x RS
|
||||
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_linear_module_handler()
|
@@ -1,6 +1,3 @@
|
||||
from curses import meta
|
||||
from math import dist
|
||||
from xml.dom import HierarchyRequestErr
|
||||
from colossalai.fx.tracer import meta_patch
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||
|
@@ -1,6 +1,3 @@
|
||||
from curses import meta
|
||||
from math import dist
|
||||
from xml.dom import HierarchyRequestErr
|
||||
from colossalai.fx.tracer import meta_patch
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
|
||||
|
Reference in New Issue
Block a user