mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
[fx] Add linear metainfo class for auto parallel (#1783)
* [fx] metainfo class for auto parallel * [fx] add unit test for linear metainfo * [fx] fix bwd param for linear * [fx] modify unit test * [fx] modify unit test * [fx] modify import * [fx] modify import * [fx] modify import * [fx] move meta profiler to auto parallel
This commit is contained in:
@@ -32,7 +32,7 @@ def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
# inputs is a list of length 3.
|
||||
input_shapes = [v.shape for v in inputs[1:3]]
|
||||
# input_shapes[0]: [batch size, input feature dimension]
|
||||
# input_shapes[1]: [batch size, output feature dimension]
|
||||
# input_shapes[1]: [input feature dimension, output feature dimension]
|
||||
assert len(input_shapes[0]) == 2, input_shapes[0]
|
||||
assert len(input_shapes[1]) == 2, input_shapes[1]
|
||||
batch_size, input_dim = input_shapes[0]
|
||||
|
Reference in New Issue
Block a user