mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[autoparallel] integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level
* unify the profiling method
* polish
* fix no codegen bug
* fix pass bug
* fix liveness test
* polish
This commit is contained in:
@@ -235,7 +235,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
# Inputs contains the shapes of two matrices.
|
||||
input_shapes = [v.shape for v in inputs]
|
||||
assert len(input_shapes) == 2, input_shapes
|
||||
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
|
||||
|
||||
# There are three cases: 1) gemm, 2) gemv, 3) dot
|
||||
if all(len(shape) == 2 for shape in input_shapes):
|
||||
# gemm
|
||||
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
|
||||
elif all(len(shape) == 1 for shape in input_shapes):
|
||||
# dot
|
||||
assert input_shapes[0][0] == input_shapes[1][0], input_shapes
|
||||
|
||||
# expand shape
|
||||
input_shapes[0] = torch.Size([1, input_shapes[0][0]])
|
||||
input_shapes[1] = torch.Size([input_shapes[1][0], 1])
|
||||
else:
|
||||
# gemv
|
||||
if len(input_shapes[0]) == 1:
|
||||
assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
|
||||
input_shapes.reverse()
|
||||
else:
|
||||
assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
|
||||
|
||||
# expand the shape of the vector to [batch size, 1]
|
||||
input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
|
||||
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
|
||||
return flops
|
||||
|
||||
|
Reference in New Issue
Block a user