[FX] refactor experimental tracer and adapt it with hf models (#3157)

* pass gpt trace and meta_prop

* pass t5 trace and meta_prop

* [FX] refactor experimental tracer and adapt it with hf models

* pass all mainstream model zoo

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* skip tests

* fix CI

* using packaging version

* polish
This commit is contained in:
YuliangLiu0306
2023-03-22 10:40:33 +08:00
committed by GitHub
parent b429529365
commit f57d34958b
28 changed files with 1014 additions and 863 deletions

View File

@@ -111,7 +111,24 @@ class ShapeProp(torch.fx.Interpreter):
with self.global_hook:
r = getattr(self, n.op)(n.target, args, kwargs)
unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
def unwrap_fn(elem):
def _convert_meta(t: torch.Tensor):
if t.device == 'meta':
return t
else:
return t.to('meta')
if isinstance(elem, MetaTensor):
return _convert_meta(elem._tensor)
elif isinstance(elem, torch.Tensor):
return _convert_meta(elem)
else:
return elem
# unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)