[auto-parallel] refactoring ColoTracer (#2118)

* add meta_data_computing

* add checkpoint_annotation

* rename proxy.data to proxy.meta_data and add bias addition pass

* polish code

* delete meta_prop_pass invoke and rename ori_node to orig_node

* add TracerType

* unify meta data computing

* delete TracerType

* handle setitem operation

* operator.setitem
Author: Zihao, 2023-01-04 14:44:22 +08:00, committed by GitHub
parent 32253315b4
commit 3a02b46447

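As context for the rename in the diff below (proxy.data becomes proxy.meta_data), here is a minimal usage sketch; the import path and the toy module are assumptions, not part of this commit:

import torch
from colossalai.fx.tracer import ColoTracer  # assumed import path

class TinyNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 8)

    def forward(self, x):
        return self.linear(x)

tracer = ColoTracer()
# during tracing, every proxy carries the meta tensor it stands for as `meta_data`
graph = tracer.trace(TinyNet(), meta_args={'x': torch.rand(2, 4, device='meta')})
print([(node.op, node.name) for node in graph.nodes])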

@@ -1,6 +1,8 @@
 import enum
 import functools
+import operator
 import inspect
+from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import torch
@@ -8,6 +10,15 @@ from torch.fx import Graph, Node, Proxy, Tracer
 from torch.utils._pytree import tree_map
 
 from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta
+from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list
+from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
+from colossalai.fx.tracer.registry import (
+    bias_addition_function,
+    bias_addition_method,
+    bias_addition_module,
+    meta_patched_function,
+    meta_patched_module,
+)
 
 if is_compatible_with_meta():
     from colossalai.fx.profiler import MetaTensor
@@ -31,18 +42,6 @@ def _truncate_suffix(s: str):
     return re.sub(r'_\d+$', '', s)
 
 
-def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
-    if isinstance(elements, (tuple, list, set)):
-        for ele in elements:
-            if ele not in list_:
-                return False, ele
-    else:
-        if elements not in list_:
-            return False, elements
-    return True, None
-
-
 def default_device():
     return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
@@ -52,24 +51,24 @@ class ColoProxy(Proxy):
     def __init__(self, *args, data=None, **kwargs):
         super().__init__(*args, **kwargs)
-        self._data = data
+        self._meta_data = data
 
     @property
-    def data(self):
-        return self._data
+    def meta_data(self):
+        return self._meta_data
 
-    @data.setter
-    def data(self, args):
+    @meta_data.setter
+    def meta_data(self, args):
         wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
-        self._data = tree_map(wrap_fn, args)
+        self._meta_data = tree_map(wrap_fn, args)
 
     @classmethod
     def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
         proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
-        unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
+        unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
         kwargs = {} if kwargs is None else kwargs
-        if proxy.data is None:
-            proxy.data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+        if proxy.meta_data is None:
+            proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
         return proxy
 
     @classmethod
@@ -77,28 +76,33 @@ class ColoProxy(Proxy):
         return cls(proxy.node, proxy.tracer)
 
     def __repr__(self):
-        return f"ColoProxy({self.node.name}, data={self.data})"
+        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
 
     def __len__(self):
-        return len(self.data)
+        return len(self.meta_data)
 
     def __int__(self):
-        return int(self.data)
+        return int(self.meta_data)
 
     def __index__(self):
         try:
-            return int(self.data)
+            return int(self.meta_data)
         except:
-            return torch.zeros(self.data.shape, dtype=torch.bool).numpy().__index__()
+            return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
 
     def __float__(self):
-        return float(self.data)
+        return float(self.meta_data)
 
     def __bool__(self):
-        return self.data
+        return self.meta_data
 
     def __getattr__(self, k):
-        return ColoAttribute(self, k, getattr(self._data, k, None))
+        return ColoAttribute(self, k, getattr(self._meta_data, k, None))
+
+    def __setitem__(self, key, value):
+        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+        proxy.meta_data = self._meta_data
+        return proxy
 
     def __contains__(self, key):
         if self.node.op == "placeholder":
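With the new __setitem__ above, indexed assignment on a traced value is recorded as a call_function node targeting operator.setitem instead of raising. A sketch under the same assumptions as the earlier example:

import operator
import torch

class MaskFirstColumn(torch.nn.Module):

    def forward(self, x):
        x[:, 0] = 0.0    # recorded as call_function -> operator.setitem
        return x

tracer = ColoTracer()
graph = tracer.trace(MaskFirstColumn(), meta_args={'x': torch.rand(2, 4, device='meta')})
assert any(n.op == 'call_function' and n.target is operator.setitem for n in graph.nodes)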
@@ -109,26 +113,26 @@ class ColoProxy(Proxy):
         return super().__contains__(key)
 
     def __isinstancecheck__(self, type):
-        return isinstance(self.data, type)
+        return isinstance(self.meta_data, type)
 
     @property
     def shape(self):
-        return self.data.shape
+        return self.meta_data.shape
 
     @property
     def ndim(self):
-        return self.data.ndim
+        return self.meta_data.ndim
 
     @property
     def device(self):
         proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
-        proxy.data = self.data.device
+        proxy.meta_data = self.meta_data.device
         return proxy
 
     @property
     def dtype(self):
         proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
-        proxy.data = self.data.dtype
+        proxy.meta_data = self.meta_data.dtype
         return proxy
 
     def to(self, *args, **kwargs):
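Note the asymmetry in the properties above: shape and ndim are answered at trace time from meta_data and emit no node, while device and dtype each emit a call_function node wrapping getattr, so the access stays visible in the graph. A sketch of the trace-time branch (toy module assumed):

class NdimProbe(torch.nn.Module):

    def forward(self, x):
        if x.ndim == 2:    # resolved from meta_data at trace time, no node emitted
            return x + 1
        return x

tracer = ColoTracer()
graph = tracer.trace(NdimProbe(), meta_args={'x': torch.rand(2, 4, device='meta')})
# only placeholder, add, and output nodes appear; the ndim check left no trace
print([node.op for node in graph.nodes])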
@@ -148,7 +152,7 @@ class ColoAttribute(ColoProxy):
         self.root = root
         self.attr = attr
         self.tracer = root.tracer
-        self._data = data
+        self._meta_data = data
         self._node: Optional[Node] = None
 
     @property
@@ -174,6 +178,12 @@ class ColoTracer(Tracer):
         self._disable_module_getattr = False
         self.proxy_buffer_attributes = True
 
+        # whether the tracer will record the usage of torch.utils.checkpoint
+        self.trace_act_ckpt = trace_act_ckpt
+        # whether the current tracing occurs within the activation checkpoint functions
+        self.inside_torch_checkpoint_func = False
+        self.act_ckpt_region_count = 0
+
     def proxy(self, node: Node) -> 'ColoProxy':
         return ColoProxy(node, self)
@@ -185,10 +195,11 @@ class ColoTracer(Tracer):
                      name: Optional[str] = None,
                      type_expr: Optional[Any] = None,
                      proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
+
         proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
-        unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
+        unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
         if kind == 'placeholder':
-            proxy.data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
+            proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
                 _truncate_suffix(target), None)
         elif kind == 'get_attr':
             self._disable_module_getattr = True
@@ -197,32 +208,39 @@ class ColoTracer(Tracer):
                 atoms = target.split(".")
                 for atom in atoms:
                     attr_itr = getattr(attr_itr, atom)
-                proxy.data = attr_itr
+                proxy.meta_data = attr_itr
             finally:
                 self._disable_module_getattr = False
         elif kind == 'call_function':
-            proxy.data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+            proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
         elif kind == 'call_method':
             self._disable_module_getattr = True
             try:
                 if target == '__call__':
-                    proxy.data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
+                    proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
                 else:
                     if target not in _TensorPropertyMethod:
-                        proxy._data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
-                                                                          **tree_map(unwrap_fn, kwargs))
+                        proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
+                                                                               **tree_map(unwrap_fn, kwargs))
             finally:
                 self._disable_module_getattr = False
         elif kind == 'call_module':
             mod = self.root.get_submodule(target)
-            unwrap_fn = lambda p: p.data if isinstance(p, ColoProxy) else p
             self._disable_module_getattr = True
             try:
-                proxy.data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+                proxy.meta_data = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
             finally:
-                self._disable_module_getattr = True
+                self._disable_module_getattr = False
         return proxy
 
+    def create_node(self, *args, **kwargs) -> Node:
+        node = super().create_node(*args, **kwargs)
+        if self.inside_torch_checkpoint_func:
+            # annotate the activation checkpoint module
+            node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+        return node
+
     def trace(self,
               root: torch.nn.Module,
               concrete_args: Optional[Dict[str, torch.Tensor]] = None,
@@ -263,11 +281,42 @@ class ColoTracer(Tracer):
         self.concrete_args = concrete_args
         self.meta_args = meta_args
 
-        with _TorchTensorOverride(self):
+        with _TorchTensorOverride(self), self.trace_activation_checkpoint(enabled=self.trace_act_ckpt):
             self.graph = super().trace(root, concrete_args=concrete_args)
         self.graph.lint()
         return self.graph
 
+    @contextmanager
+    def trace_activation_checkpoint(self, enabled: bool):
+        if enabled:
+            orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
+
+            class PatchedCheckpointFunction(torch.autograd.Function):
+
+                @staticmethod
+                def forward(ctx, run_function, preserve_rng_state, *args):
+                    # signal that the current tracing occurs within the activation checkpoint part
+                    self.inside_torch_checkpoint_func = True
+                    out = run_function(*args)
+                    self.inside_torch_checkpoint_func = False
+                    self.act_ckpt_region_count += 1
+                    return out
+
+                @staticmethod
+                def backward(ctx: Any, *grad_outputs: Any) -> Any:
+                    raise NotImplementedError("We do not implement the backward pass as we only trace the forward pass.")
+
+            # override the checkpoint function
+            torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
+
+        yield
+
+        if enabled:
+            # recover the checkpoint function upon exit
+            torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
+
     def _post_check(self, non_concrete_arg_names: Set[str]):
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
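The effect of the checkpoint patching above can be observed as follows; `trace_act_ckpt` and the `'activation_checkpoint'` meta key come from this diff, while the toy model is an assumption:

import torch
import torch.utils.checkpoint as ckpt

class CkptNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = ckpt.checkpoint(self.l1, x)    # checkpointed region 0
        return self.l2(x)    # outside any checkpointed region

tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(CkptNet(), meta_args={'x': torch.rand(2, 4, device='meta')})
for node in graph.nodes:
    # nodes created inside a checkpointed region carry that region's index
    print(node.name, node.meta.get('activation_checkpoint'))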
@@ -392,3 +441,202 @@ class _TorchTensorOverride(object):
     def __exit__(self, exc_type, exc_val, exc_tb):
         for name, (wrapper, orig) in self.overrides.items():
             setattr(torch, name, orig)
+
+
+def meta_prop_pass(gm: ColoGraphModule,
+                   root: torch.nn.Module,
+                   meta_args: Optional[Dict[str, Any]] = None,
+                   concrete_args: Optional[Dict[str, torch.Tensor]] = None):
+    if meta_args is None:
+        meta_args = {}
+
+    if concrete_args is None:
+        concrete_args = {}
+
+    # check concrete and meta args have valid names
+    sig = inspect.signature(root.forward)
+    sig_names = set(sig.parameters.keys())
+    meta_arg_names = set(meta_args.keys())
+
+    # update concrete args with default values
+    non_meta_arg_names = sig_names - meta_arg_names
+    for k, v in sig.parameters.items():
+        if k in non_meta_arg_names and \
+                k not in concrete_args and \
+                v.default is not inspect.Parameter.empty:
+            concrete_args[k] = v.default
+
+    for node in gm.graph.nodes:
+        node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
+                                               node.kwargs)
+
+
+def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
+    unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
+    meta_out = None
+
+    if kind == 'placeholder':
+        meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
+    elif kind == 'get_attr':
+        attr_itr = root
+        atoms = target.split(".")
+        for atom in atoms:
+            attr_itr = getattr(attr_itr, atom)
+        meta_out = attr_itr
+    elif kind == 'call_function':
+        meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+    elif kind == 'call_method':
+        if target == '__call__':
+            meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
+        elif target not in _TensorPropertyMethod:
+            meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
+                                                           **tree_map(unwrap_fn, kwargs))
+    elif kind == 'call_module':
+        mod = root.get_submodule(target)
+        meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+
+    return meta_out
+
+
+def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
+    if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
+        meta_out = meta_args[target]
+        return meta_out
+
+    if target in [getattr(torch, torch_func) for torch_func in _TorchNewMethod]:
+        # NOTE: tensor constructors in PyTorch define the `device` argument as
+        # *kwargs-only*. That is why this works. If you add methods to
+        # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
+        # this will break and you will likely see issues where we cannot infer
+        # the size of the output.
+        if "device" in kwargs:
+            kwargs["device"] = "meta"
+
+    try:
+        unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
+        args_metas = tree_map(unwrap_fn, args)
+        kwargs_metas = tree_map(unwrap_fn, kwargs)
+
+        if kind == "call_function":
+            # fetch patched function
+            if meta_patched_function.has(target):
+                meta_target = meta_patched_function.get(target)
+            elif meta_patched_function.has(target.__name__):
+                # use name for some builtin op like @ (matmul)
+                meta_target = meta_patched_function.get(target.__name__)
+            else:
+                meta_target = target
+            meta_out = meta_target(*args_metas, **kwargs_metas)
+            if isinstance(meta_out, torch.Tensor):
+                meta_out = meta_out.to(device="meta")
+        elif kind == "call_method":
+            method = getattr(args_metas[0].__class__, target)
+            # fetch patched method
+            if meta_patched_function.has(method):
+                meta_target = meta_patched_function.get(method)
+            else:
+                meta_target = method
+            meta_out = meta_target(*args_metas, **kwargs_metas)
+        elif kind == "call_module":
+            mod = root.get_submodule(target)
+            mod_type = type(mod)
+            if meta_patched_module.has(mod_type):
+                meta_out = meta_patched_module.get(mod_type)(mod, *args_metas, **kwargs_metas)
+            else:
+                meta_out = mod(*args_metas, **kwargs_metas)
+        elif kind == "get_attr":
+            attr_itr = root
+            atoms = target.split(".")
+            for atom in atoms:
+                attr_itr = getattr(attr_itr, atom)
+            if isinstance(attr_itr, torch.nn.parameter.Parameter):
+                meta_out = torch.nn.Parameter(attr_itr.to(device="meta"))
+            elif isinstance(attr_itr, torch.Tensor):
+                meta_out = attr_itr.to(device="meta")
+            else:
+                meta_out = attr_itr
+        else:
+            return None
+    except Exception as e:
+        raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
+
+    return meta_out
+
+
+def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
+    result_graph = Graph()
+    value_remap = {}
+    unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
+
+    for orig_node in gm.graph.nodes:
+        assert hasattr(orig_node, "_meta_data")
+        kind = orig_node.op
+        target = orig_node.target
+        args = orig_node.args
+        kwargs = orig_node.kwargs
+
+        args_metas = tree_map(unwrap_fn, args)
+        tracer = ColoTracer()
+        tracer.graph = Graph(tracer_cls=ColoTracer)
+        tracer.root = root_model
+
+        def wrap_fn(n):
+            if isinstance(n, Node):
+                proxy = ColoProxy(n, tracer)
+                proxy.meta_data = n._meta_data
+                return proxy
+            return n
+
+        args_proxy = tree_map(wrap_fn, args)
+        kwargs_proxy = tree_map(wrap_fn, kwargs)
+
+        handle = None
+        if kind == "call_function":
+            if bias_addition_function.has(target):
+                if target == torch.nn.functional.linear:
+                    if 'bias' in kwargs and kwargs['bias'] is not None:
+                        function_to_substitute = func_to_func_dict[target]
+                        handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+                                                                    function_to_substitute)
+                else:
+                    function_to_substitute = func_to_func_dict[target]
+                    handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+                                                                function_to_substitute)
+            elif bias_addition_function.has(target.__name__):
+                # use name for some builtin op like @ (matmul)
+                function_to_substitute = func_to_func_dict[target]
+                handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
+                                                                     function_to_substitute)
+        elif kind == "call_method":
+            method = getattr(args_metas[0].__class__, target)
+            if bias_addition_method.has(method):
+                function_to_substitute = method_to_func_dict[method]
+                handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
+                                                          function_to_substitute)
+        elif kind == "call_module":
+            mod = gm.get_submodule(target)
+            mod_type = type(mod)
+            if bias_addition_module.has(mod_type) and mod.bias is not None:
+                function_to_substitute = module_to_func_dict[mod_type]
+                handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
+                                                            function_to_substitute)
+
+        if handle is not None:
+            handle.generate()
+            for node_inserted in tracer.graph.nodes:
+                value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
+                last_node = value_remap[node_inserted]
+            value_remap[orig_node] = last_node
+        else:
+            value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
+
+        del tracer
+
+    gm.graph = result_graph
+    gm.recompile()
+
+    meta_prop_pass(gm, root_model, meta_args)
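Taken together, the intended flow appears to be: trace, then rewrite bias additions, with bias_addition_pass re-running meta_prop_pass at its end. A usage sketch reusing the assumed TinyNet from the first example (the ColoGraphModule constructor call mirrors torch.fx.GraphModule and is an assumption):

import torch

model = TinyNet()
meta_args = {'x': torch.rand(2, 4, device='meta')}

tracer = ColoTracer()
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph, model.__class__.__name__)

meta_prop_pass(gm, model, meta_args)    # attach _meta_data to every node
bias_addition_pass(gm, model, meta_args)    # e.g. split a biased linear into matmul + add, then re-propagate metadata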