[autoparallel] added addbmm handler (#1751)

This commit is contained in:
Frank Lee
2022-10-21 18:55:48 +08:00
committed by GitHub
parent 980ed21723
commit 262652c8bc
8 changed files with 353 additions and 35 deletions


@@ -1,4 +1,5 @@
import torch
from ..registry import meta_patched_function
@@ -56,6 +57,16 @@ def torch_bmm(input, mat2, *, out=None):
    return torch.empty(batch_size, n, p, device="meta")


@meta_patched_function.register(torch.addbmm)
@meta_patched_function.register(torch.Tensor.addbmm)
def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
    if out is not None:
        raise ValueError("Don't support in-place addbmm for MetaTensor analysis")
    # mat1 is (b, n, m) and mat2 is (b, m, p); the batched product is
    # summed over the batch dimension, so the result shape is (n, p)
    batch_size, n, m = mat1.shape
    _, _, p = mat2.shape
    return torch.empty(n, p, device="meta")


@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
    assert out is None, 'saving to out is not supported yet'
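
For context, torch.addbmm multiplies each pair of batch matrices mat1 (b, n, m) and mat2 (b, m, p), sums the products over the batch dimension, and adds the result to input, so the output shape is (n, p); beta and alpha only scale the two terms and never affect the shape. A minimal sketch exercising the new handler (shapes are illustrative, not from the commit, and it assumes torch_addbmm from the hunk above is in scope):

import torch

# Illustrative shapes: batch of 10, per-batch matmul (3, 4) @ (4, 5).
inp = torch.empty(3, 5, device="meta")
mat1 = torch.empty(10, 3, 4, device="meta")
mat2 = torch.empty(10, 4, 5, device="meta")

# The handler drops the batch dimension and returns a meta tensor
# with the reduced shape (n, p), matching torch.addbmm's semantics.
out = torch_addbmm(inp, mat1, mat2)
assert out.shape == (3, 5)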