mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[fx] refactor tracer to trace complete graph (#1342)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [fx] refactor tracer to trace complete graph
* add comments and solve conflicts.
This commit is contained in:
@@ -2,6 +2,7 @@ import operator
|
||||
import torch
|
||||
from torch.fx.proxy import Proxy, Attribute
|
||||
from typing import List, Union, Any
|
||||
from colossalai.fx.tracer.meta_patch import meta_patched_function
|
||||
|
||||
__all__ = ['ColoProxy']
|
||||
|
||||
@@ -45,6 +46,14 @@ class ColoProxy(Proxy):
|
||||
self._assert_has_meta_data()
|
||||
return len(self.meta_data)
|
||||
|
||||
def __int__(self):
|
||||
self._assert_has_meta_data()
|
||||
return int(self.meta_data)
|
||||
|
||||
def __float__(self):
|
||||
self._assert_has_meta_data()
|
||||
return float(self.meta_data)
|
||||
|
||||
def __bool__(self):
|
||||
self._assert_has_meta_data()
|
||||
return self.meta_data
|
||||
@@ -53,9 +62,6 @@ class ColoProxy(Proxy):
|
||||
|
||||
return ColoAttribute(self, k)
|
||||
|
||||
def __setitem__(self, indices, values):
|
||||
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
||||
|
||||
def __contains__(self, key):
|
||||
if self.node.op == "placeholder":
|
||||
# this is used to handle like
|
||||
@@ -65,11 +71,26 @@ class ColoProxy(Proxy):
|
||||
return super().__contains__(key)
|
||||
|
||||
|
||||
def extract_meta(*args, **kwargs):
|
||||
"""
|
||||
This function is copied from _tracer_utils.py to avoid circular import issue.
|
||||
"""
|
||||
|
||||
def _convert(val):
|
||||
if isinstance(val, ColoProxy):
|
||||
return val.meta_data
|
||||
elif isinstance(val, (list, tuple)):
|
||||
return type(val)([_convert(ele) for ele in val])
|
||||
return val
|
||||
|
||||
new_args = [_convert(val) for val in args]
|
||||
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
|
||||
return new_args, new_kwargs
|
||||
|
||||
|
||||
class ColoAttribute(ColoProxy):
|
||||
|
||||
def __init__(self, root, attr: str):
|
||||
# this class is copied from torch.fx.Attribute
|
||||
# but inherits ColoProxy
|
||||
self.root = root
|
||||
self.attr = attr
|
||||
self.tracer = root.tracer
|
||||
@@ -78,8 +99,28 @@ class ColoAttribute(ColoProxy):
|
||||
@property
|
||||
def node(self):
|
||||
if self._node is None:
|
||||
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
|
||||
proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
|
||||
if not isinstance(proxy, ColoProxy):
|
||||
meta_args, meta_kwargs = extract_meta(*(self.root, self.attr))
|
||||
meta_out = getattr(*meta_args, **meta_kwargs)
|
||||
proxy = ColoProxy(proxy.node)
|
||||
proxy.meta_data = meta_out
|
||||
self._node = proxy.node
|
||||
|
||||
return self._node
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
|
||||
proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
|
||||
if not isinstance(proxy, ColoProxy):
|
||||
meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs)
|
||||
method = getattr(meta_args[0].__class__, self.attr)
|
||||
if meta_patched_function.has(method):
|
||||
meta_target = meta_patched_function.get(method)
|
||||
elif meta_patched_function.has(target.__name__):
|
||||
meta_target = meta_patched_function.get(target.__name__)
|
||||
else:
|
||||
meta_target = method
|
||||
meta_out = meta_target(*meta_args, **meta_kwargs)
|
||||
proxy = ColoProxy(proxy.node)
|
||||
proxy.meta_data = meta_out
|
||||
return proxy
|
||||
|
Reference in New Issue
Block a user