[fx] refactor tracer to trace complete graph (#1342)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] refactor tracer to trace complete graph

* add comments and solve conflicts.
This commit is contained in:
YuliangLiu0306
2022-07-20 11:20:38 +08:00
committed by GitHub
parent 2cc1175c76
commit 942c8cd1fb
9 changed files with 160 additions and 20 deletions

View File

@@ -1,5 +1,7 @@
from typing import List, Union, Any
from ..proxy import ColoProxy, ColoAttribute
import torch
from .meta_patch import meta_patched_function, meta_patched_module
__all__ = ['is_element_in_list', 'extract_meta']
@@ -29,3 +31,20 @@ def extract_meta(*args, **kwargs):
new_args = [_convert(val) for val in args]
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
return new_args, new_kwargs
def compute_meta_data_for_functions_proxy(target, args, kwargs):
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = target
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
return meta_out