[autoparallel]integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
This commit is contained in:
YuliangLiu0306
2023-04-04 17:40:45 +08:00
committed by GitHub
parent 573af84184
commit ffcdbf0f65
46 changed files with 396 additions and 470 deletions

View File

@@ -237,7 +237,14 @@ class ShapeProp(torch.fx.Interpreter):
Returns:
Any: The value returned from executing the Module
"""
wrap_fn = lambda elem: MetaTensor(elem, device=device)
# wrap_fn = lambda elem: MetaTensor(elem, device=device)
def wrap_fn(elem, device=device):
if isinstance(elem, torch.Tensor):
return MetaTensor(elem, device=device)
else:
return elem
with self._mode:
return super().run(*tree_map(wrap_fn, args))