[fx] refactor tracer to trace complete graph (#1342)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] refactor tracer to trace complete graph

* add comments and solve conflicts.
This commit is contained in:
YuliangLiu0306
2022-07-20 11:20:38 +08:00
committed by GitHub
parent 2cc1175c76
commit 942c8cd1fb
9 changed files with 160 additions and 20 deletions

View File

@@ -2,6 +2,7 @@ import operator
import torch
from torch.fx.proxy import Proxy, Attribute
from typing import List, Union, Any
from colossalai.fx.tracer.meta_patch import meta_patched_function
__all__ = ['ColoProxy']
@@ -45,6 +46,14 @@ class ColoProxy(Proxy):
self._assert_has_meta_data()
return len(self.meta_data)
def __int__(self):
self._assert_has_meta_data()
return int(self.meta_data)
def __float__(self):
self._assert_has_meta_data()
return float(self.meta_data)
def __bool__(self):
self._assert_has_meta_data()
return self.meta_data
@@ -53,9 +62,6 @@ class ColoProxy(Proxy):
return ColoAttribute(self, k)
def __setitem__(self, indices, values):
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
def __contains__(self, key):
if self.node.op == "placeholder":
# this is used to handle like
@@ -65,11 +71,26 @@ class ColoProxy(Proxy):
return super().__contains__(key)
def extract_meta(*args, **kwargs):
"""
This function is copied from _tracer_utils.py to avoid circular import issue.
"""
def _convert(val):
if isinstance(val, ColoProxy):
return val.meta_data
elif isinstance(val, (list, tuple)):
return type(val)([_convert(ele) for ele in val])
return val
new_args = [_convert(val) for val in args]
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
return new_args, new_kwargs
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str):
# this class is copied from torch.fx.Attribute
# but inherits ColoProxy
self.root = root
self.attr = attr
self.tracer = root.tracer
@@ -78,8 +99,28 @@ class ColoAttribute(ColoProxy):
@property
def node(self):
if self._node is None:
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
if not isinstance(proxy, ColoProxy):
meta_args, meta_kwargs = extract_meta(*(self.root, self.attr))
meta_out = getattr(*meta_args, **meta_kwargs)
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
self._node = proxy.node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
if not isinstance(proxy, ColoProxy):
meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs)
method = getattr(meta_args[0].__class__, self.attr)
if meta_patched_function.has(method):
meta_target = meta_patched_function.get(method)
elif meta_patched_function.has(target.__name__):
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = method
meta_out = meta_target(*meta_args, **meta_kwargs)
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
return proxy