[fx] refactored the file structure of patched function and module (#1238)

* [fx] refactored the file structure of patched function and module

* polish code
This commit is contained in:
Frank Lee
2022-07-12 15:01:58 +08:00
committed by GitHub
parent 17ed33350b
commit 7531c6271f
15 changed files with 353 additions and 318 deletions

View File

@@ -0,0 +1,24 @@
import operator
import torch
from ..registry import meta_patched_function
@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
# copied from huggingface.utils.fx
def to_concrete(t):
if isinstance(t, torch.Tensor):
concrete = torch.ones_like(t, device="cpu")
if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
concrete = concrete.to(torch.int64)
return concrete
return t
if isinstance(a, torch.Tensor):
# TODO: infer shape without performing the computation.
if isinstance(b, tuple):
b = tuple(map(to_concrete, b))
else:
b = to_concrete(b)
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
return operator.getitem(a, b)