[fx] Add linear metainfo class for auto parallel (#1783)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel
This commit is contained in:
Boyuan Yao
2022-11-04 10:55:09 +08:00
committed by GitHub
parent e8a9bebc87
commit 05ce3d369f
10 changed files with 516 additions and 2 deletions

View File

@@ -32,7 +32,7 @@ def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# inputs is a list of length 3.
input_shapes = [v.shape for v in inputs[1:3]]
# input_shapes[0]: [batch size, input feature dimension]
# input_shapes[1]: [batch size, output feature dimension]
# input_shapes[1]: [input feature dimension, output feature dimension]
assert len(input_shapes[0]) == 2, input_shapes[0]
assert len(input_shapes[1]) == 2, input_shapes[1]
batch_size, input_dim = input_shapes[0]