mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[fx] refactor tracer to trace complete graph (#1342)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [fx] refactor tracer to trace complete graph
* add comments and solve conflicts.
This commit is contained in:
@@ -7,6 +7,7 @@ tracer.py:
|
||||
import enum
|
||||
import inspect
|
||||
import functools
|
||||
import operator
|
||||
from colossalai.fx.tracer.meta_patch import meta_patched_module
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -16,8 +17,9 @@ from torch.fx.graph import Graph
|
||||
from torch.fx.proxy import Proxy, ParameterProxy
|
||||
from ..proxy import ColoProxy
|
||||
from typing import Optional, Dict, Any
|
||||
from ._tracer_utils import is_element_in_list, extract_meta
|
||||
from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy
|
||||
from .meta_patch import meta_patched_function, meta_patched_module
|
||||
from torch.fx.graph import magic_methods, reflectable_magic_methods
|
||||
|
||||
__all__ = ['ColoTracer']
|
||||
|
||||
@@ -61,7 +63,7 @@ class ColoTracer(Tracer):
|
||||
# Feature flag for proxying accesses to buffer values
|
||||
proxy_buffer_attributes: bool = True
|
||||
|
||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
|
||||
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"]
|
||||
|
||||
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy:
|
||||
"""
|
||||
@@ -344,11 +346,15 @@ def wrap_tensor_constructor_method(target):
|
||||
for arg in args:
|
||||
if isinstance(arg, Proxy):
|
||||
return arg
|
||||
if isinstance(arg, (tuple, list)):
|
||||
return look_for_proxy(*arg)
|
||||
|
||||
# find in keyword vars
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, Proxy):
|
||||
return v
|
||||
if isinstance(v, (tuple, list)):
|
||||
return look_for_proxy(*v)
|
||||
return None
|
||||
|
||||
@functools.wraps(target)
|
||||
@@ -358,10 +364,60 @@ def wrap_tensor_constructor_method(target):
|
||||
if proxy is not None:
|
||||
# if the arg is a proxy, then need to record this function called on this proxy
|
||||
# e.g. torch.ones(size) where size is an input proxy
|
||||
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
|
||||
colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
|
||||
if not isinstance(colo_proxy, ColoProxy):
|
||||
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
|
||||
colo_proxy = ColoProxy(fx_proxy.node)
|
||||
colo_proxy.meta_data = meta_out
|
||||
return colo_proxy
|
||||
else:
|
||||
# this is called directly when the inputs do not contain proxy
|
||||
# e.g. torch.ones(4) where the input is static
|
||||
return target(*args, **kwargs)
|
||||
|
||||
return wrapper, target
|
||||
|
||||
|
||||
# Patched magic methods for ColoProxy, then tracer could record the magic_method like __sub__,
|
||||
# and add meta_data attribute to the created proxy.
|
||||
for method in magic_methods:
|
||||
|
||||
def _scope(method):
|
||||
|
||||
def impl(*args, **kwargs):
|
||||
|
||||
tracer = args[0].tracer
|
||||
target = getattr(operator, method)
|
||||
proxy = tracer.create_proxy('call_function', target, args, kwargs)
|
||||
if not isinstance(proxy, ColoProxy):
|
||||
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
|
||||
proxy = ColoProxy(proxy.node)
|
||||
proxy.meta_data = meta_out
|
||||
return proxy
|
||||
|
||||
impl.__name__ = method
|
||||
as_magic = f'__{method.strip("_")}__'
|
||||
setattr(ColoProxy, as_magic, impl)
|
||||
|
||||
_scope(method)
|
||||
|
||||
|
||||
def _define_reflectable(orig_method_name):
|
||||
method_name = f'__r{orig_method_name.strip("_")}__'
|
||||
|
||||
def impl(self, rhs):
|
||||
target = getattr(operator, orig_method_name)
|
||||
proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
|
||||
if not isinstance(proxy, ColoProxy):
|
||||
meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
|
||||
proxy = ColoProxy(proxy.node)
|
||||
proxy.meta_data = meta_out
|
||||
return proxy
|
||||
|
||||
impl.__name__ = method_name
|
||||
impl.__qualname__ = method_name
|
||||
setattr(ColoProxy, method_name, impl)
|
||||
|
||||
|
||||
for orig_method_name in reflectable_magic_methods:
|
||||
_define_reflectable(orig_method_name)
|
||||
|
Reference in New Issue
Block a user