[autoparallel] support addmm in tracer and solver (#1961)

* [fx] patch addmm

* [autoparallel] support addmm in tracer and solver
This commit is contained in:
YuliangLiu0306
2022-11-16 14:59:18 +08:00
committed by GitHub
parent f7e276fa71
commit fea3cb661c
7 changed files with 328 additions and 21 deletions

View File

@@ -72,11 +72,21 @@ def torch_linear(input, mat2, bias=None, *, out=None):
def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
batch_size, n, m = mat1.shape
_, n, _ = mat1.shape
_, _, p = mat2.shape
return torch.empty(n, p, device="meta")
@meta_patched_function.register(torch.addmm)
@meta_patched_function.register(torch.Tensor.addmm)
def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
n, _ = mat1.shape
_, p = mat2.shape
return torch.empty(n, p, device="meta")
@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
assert out is None, 'saving to out is not supported yet'