mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[autoparallel] implemented linear projection strategy generator (#1639)
This commit is contained in:
@@ -84,13 +84,13 @@ def test_linear_function_handler():
|
||||
assert mapping['other'].name == "weight"
|
||||
assert mapping['other'].data.is_meta
|
||||
assert mapping['other'].data.shape == torch.Size([20, 10])
|
||||
assert mapping['other'].type == OperationDataType.ARG
|
||||
assert mapping['other'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
||||
|
||||
assert mapping['bias'].name == "bias"
|
||||
assert mapping['bias'].data.is_meta
|
||||
assert mapping['bias'].data.shape == torch.Size([20])
|
||||
assert mapping['bias'].type == OperationDataType.ARG
|
||||
assert mapping['bias'].type == OperationDataType.PARAM
|
||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
||||
|
||||
assert mapping['output'].name == "linear"
|
||||
@@ -100,5 +100,5 @@ def test_linear_function_handler():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_linear_module_handler()
|
||||
test_linear_module_handler()
|
||||
test_linear_function_handler()
|
||||
|
Reference in New Issue
Block a user