[fx] refactor tracer to trace complete graph (#1342)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [fx] refactor tracer to trace complete graph

* add comments and solve conflicts.
This commit is contained in:
YuliangLiu0306
2022-07-20 11:20:38 +08:00
committed by GitHub
parent 2cc1175c76
commit 942c8cd1fb
9 changed files with 160 additions and 20 deletions

View File

@@ -2,6 +2,7 @@ import operator
import torch
from torch.fx.proxy import Proxy, Attribute
from typing import List, Union, Any
from colossalai.fx.tracer.meta_patch import meta_patched_function
__all__ = ['ColoProxy']
@@ -45,6 +46,14 @@ class ColoProxy(Proxy):
self._assert_has_meta_data()
return len(self.meta_data)
def __int__(self):
self._assert_has_meta_data()
return int(self.meta_data)
def __float__(self):
self._assert_has_meta_data()
return float(self.meta_data)
def __bool__(self):
self._assert_has_meta_data()
return self.meta_data
@@ -53,9 +62,6 @@ class ColoProxy(Proxy):
return ColoAttribute(self, k)
def __setitem__(self, indices, values):
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
def __contains__(self, key):
if self.node.op == "placeholder":
# this is used to handle like
@@ -65,11 +71,26 @@ class ColoProxy(Proxy):
return super().__contains__(key)
def extract_meta(*args, **kwargs):
"""
This function is copied from _tracer_utils.py to avoid circular import issue.
"""
def _convert(val):
if isinstance(val, ColoProxy):
return val.meta_data
elif isinstance(val, (list, tuple)):
return type(val)([_convert(ele) for ele in val])
return val
new_args = [_convert(val) for val in args]
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
return new_args, new_kwargs
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str):
# this class is copied from torch.fx.Attribute
# but inherits ColoProxy
self.root = root
self.attr = attr
self.tracer = root.tracer
@@ -78,8 +99,28 @@ class ColoAttribute(ColoProxy):
@property
def node(self):
if self._node is None:
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {})
if not isinstance(proxy, ColoProxy):
meta_args, meta_kwargs = extract_meta(*(self.root, self.attr))
meta_out = getattr(*meta_args, **meta_kwargs)
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
self._node = proxy.node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
if not isinstance(proxy, ColoProxy):
meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs)
method = getattr(meta_args[0].__class__, self.attr)
if meta_patched_function.has(method):
meta_target = meta_patched_function.get(method)
elif meta_patched_function.has(target.__name__):
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = method
meta_out = meta_target(*meta_args, **meta_kwargs)
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
return proxy

View File

@@ -1,5 +1,7 @@
from typing import List, Union, Any
from ..proxy import ColoProxy, ColoAttribute
import torch
from .meta_patch import meta_patched_function, meta_patched_module
__all__ = ['is_element_in_list', 'extract_meta']
@@ -29,3 +31,20 @@ def extract_meta(*args, **kwargs):
new_args = [_convert(val) for val in args]
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
return new_args, new_kwargs
def compute_meta_data_for_functions_proxy(target, args, kwargs):
args_metas, kwargs_metas = extract_meta(*args, **kwargs)
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = target
meta_out = meta_target(*args_metas, **kwargs_metas)
if isinstance(meta_out, torch.Tensor):
meta_out = meta_out.to(device="meta")
return meta_out

View File

@@ -24,6 +24,11 @@ def torch_arange(*args, **kwargs):
return torch.empty((end - start) // step, dtype=dtype, device="meta")
@meta_patched_function.register(torch.finfo)
def torch_finfo(*args):
return torch.finfo(*args)
@meta_patched_function.register(torch.where)
def torch_where(condition, x, y):
# torch.where returns the broadcasted tensor of condition, x, and y,

View File

@@ -7,6 +7,7 @@ tracer.py:
import enum
import inspect
import functools
import operator
from colossalai.fx.tracer.meta_patch import meta_patched_module
import torch
import torch.nn as nn
@@ -16,8 +17,9 @@ from torch.fx.graph import Graph
from torch.fx.proxy import Proxy, ParameterProxy
from ..proxy import ColoProxy
from typing import Optional, Dict, Any
from ._tracer_utils import is_element_in_list, extract_meta
from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
from .meta_patch import meta_patched_function, meta_patched_module
from torch.fx.graph import magic_methods, reflectable_magic_methods
__all__ = ['ColoTracer']
@@ -61,7 +63,7 @@ class ColoTracer(Tracer):
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
"""
@@ -344,11 +346,15 @@ def wrap_tensor_constructor_method(target):
for arg in args:
if isinstance(arg, Proxy):
return arg
if isinstance(arg, (tuple, list)):
return look_for_proxy(*arg)
# find in keyword vars
for k, v in kwargs.items():
if isinstance(v, Proxy):
return v
if isinstance(v, (tuple, list)):
return look_for_proxy(*v)
return None
@functools.wraps(target)
@@ -358,10 +364,60 @@ def wrap_tensor_constructor_method(target):
if proxy is not None:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
if not isinstance(colo_proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
colo_proxy = ColoProxy(fx_proxy.node)
colo_proxy.meta_data = meta_out
return colo_proxy
else:
# this is called directly when the inputs do not contain proxy
# e.g. torch.ones(4) where the input is static
return target(*args, **kwargs)
return wrapper, target
# Patched magic methods for ColoProxy, then tracer could record the magic_method like __sub__,
# and add meta_data attribute to the created proxy.
for method in magic_methods:
def _scope(method):
def impl(*args, **kwargs):
tracer = args[0].tracer
target = getattr(operator, method)
proxy = tracer.create_proxy('call_function', target, args, kwargs)
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
return proxy
impl.__name__ = method
as_magic = f'__{method.strip("_")}__'
setattr(ColoProxy, as_magic, impl)
_scope(method)
def _define_reflectable(orig_method_name):
method_name = f'__r{orig_method_name.strip("_")}__'
def impl(self, rhs):
target = getattr(operator, orig_method_name)
proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
proxy = ColoProxy(proxy.node)
proxy.meta_data = meta_out
return proxy
impl.__name__ = method_name
impl.__qualname__ = method_name
setattr(ColoProxy, method_name, impl)
for orig_method_name in reflectable_magic_methods:
_define_reflectable(orig_method_name)