[autoparallel] add cost graph class (#1481)

* [autoparallel] add cost graph class

* polish code
This commit is contained in:
YuliangLiu0306
2022-08-25 17:19:59 +08:00
committed by GitHub
parent 4b03c25f85
commit 413c053453
6 changed files with 141 additions and 5 deletions

View File

@@ -15,7 +15,7 @@ def torch_matmul(input, other, *, out=None):
shape = (input.size(0), other.size(1))
elif d1 == 1 and d2 == 2:
shape = (other.size(1),)
elif d1 == 2 and d1 == 1:
elif d1 == 2 and d2 == 1:
shape = (input.size(0),)
else:
max_length = max(input.dim(), other.dim())