[autoparallel] implemented linear projection strategy generator (#1639)

This commit is contained in:
Frank Lee
2022-09-26 16:58:14 +08:00
committed by GitHub
parent 154d3ef432
commit 45b39a692a
7 changed files with 564 additions and 134 deletions

View File

@@ -84,13 +84,13 @@ def test_linear_function_handler():
assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([20, 10])
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([10, 20])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([20])
assert mapping['bias'].type == OperationDataType.ARG
assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([10, 20])
assert mapping['output'].name == "linear"
@@ -100,5 +100,5 @@ def test_linear_function_handler():
if __name__ == '__main__':
# test_linear_module_handler()
test_linear_module_handler()
test_linear_function_handler()