[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -12,7 +12,7 @@ from .tracer import register_tracer_impl
__all__ = []
@register_tracer_impl(F.linear, name='_bias_addition_impl')
@register_tracer_impl(F.linear, name="_bias_addition_impl")
def linear_impl(input, weight, bias=None):
if bias is None:
return F.linear(input, weight)
@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None):
return F.linear(input, weight) + bias
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
@register_tracer_impl(F.conv1d, name="_bias_addition_impl")
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1))
(-1, 1)
)
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
@register_tracer_impl(F.conv2d, name="_bias_addition_impl")
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1))
(-1, 1, 1)
)
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
@register_tracer_impl(F.conv3d, name="_bias_addition_impl")
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1, 1))
(-1, 1, 1, 1)
)
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl")
def conv_transpose1d_impl(
input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1),
):
if bias is None:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose1d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
else:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))
return F.conv_transpose1d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input,
weight,
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl")
def conv_transpose2d_impl(
input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)
):
if bias is None:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose2d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
else:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))
return F.conv_transpose2d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl")
def conv_transpose3d_impl(
input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1),
):
if bias is None:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose3d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
else:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
return F.conv_transpose3d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.addmm, name="_bias_addition_impl")
@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl")
def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
return F.linear(mat1, mat2.transpose(0, 1)) + input
@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
@register_tracer_impl(torch.addbmm, name="_bias_addition_impl")
@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl")
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
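
A minimal sketch (not part of this commit) of what the _bias_addition_impl functions in the hunk above do: the fused bias term is split off into an explicit add so the tracer records it as its own node, and for convolutions the per-channel bias has to be reshaped so it broadcasts over the spatial dimensions. Shapes below are illustrative.

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)   # (N, C_in, H, W)
w = torch.randn(4, 3, 3, 3)   # (C_out, C_in, kH, kW)
b = torch.randn(4)            # one bias value per output channel

fused = F.conv2d(x, w, bias=b, padding=1)
# same computation with the bias addition split out, as in conv2d_impl above:
# reshape to (-1, 1, 1) so the (4,) bias broadcasts over the (N, 4, H, W) output
split = F.conv2d(x, w, padding=1) + b.reshape((-1, 1, 1))
assert torch.allclose(fused, split, atol=1e-6)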


@@ -4,6 +4,7 @@ from .tracer import register_leaf_module, register_leaf_module_impl
try:
import apex
register_leaf_module(apex.normalization.FusedLayerNorm)
register_leaf_module(apex.normalization.FusedRMSNorm)
register_leaf_module(apex.normalization.MixedFusedLayerNorm)
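
A rough sketch of why apex's fused norm layers are registered as leaf modules above, assuming standard torch.fx semantics (FusedNormLike is a made-up stand-in, not an apex class): a leaf module is kept as a single call_module node instead of being traced through, which matters for third-party modules that live outside torch.nn.

import torch.fx as fx
import torch.nn as nn

class FusedNormLike(nn.Module):
    # stand-in for a third-party fused op the tracer should not step into
    def forward(self, x):
        return (x - x.mean(-1, keepdim=True)) / (x.std(-1, keepdim=True) + 1e-5)

class LeafAwareTracer(fx.Tracer):
    def is_leaf_module(self, m, qualname):
        # keep the custom module opaque, like register_leaf_module does for apex norms
        return isinstance(m, FusedNormLike) or super().is_leaf_module(m, qualname)

model = nn.Sequential(nn.Linear(4, 4), FusedNormLike())
print(fx.GraphModule(model, LeafAwareTracer().trace(model)).graph)
# FusedNormLike shows up as one call_module node rather than mean/std/div nodes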


@@ -1,10 +1,8 @@
import operator
from typing import Any, Callable, Dict, Optional, Set, Union
from typing import Any, Callable, Dict, Optional, Union
import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.fx import Node, Proxy
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
@@ -32,7 +30,7 @@ class ColoProxy(Proxy):
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch:
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl
return proxy
@@ -72,7 +70,7 @@ class ColoProxy(Proxy):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
@@ -89,7 +87,6 @@ class ColoProxy(Proxy):
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
@@ -102,11 +99,11 @@ class ColoAttribute(ColoProxy):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"


@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.fx import Tracer
@@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor
try:
from ..codegen import ActivationCheckpointCodeGen
SUPPORT_ACTIVATION = True
except:
SUPPORT_ACTIVATION = False
@@ -16,7 +17,7 @@ from .tracer import ColoTracer
def _default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _current_device(module: torch.nn.Module):
@@ -144,10 +145,9 @@ def symbolic_trace(
if meta_args:
device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
bias_addition_split=bias_addition_split).trace(root.to(device),
concrete_args=concrete_args,
meta_args=tree_map(wrap_fn, meta_args))
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
)
if trace_act_ckpt and SUPPORT_ACTIVATION:
graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device)
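
A hedged usage sketch of symbolic_trace based only on the keyword arguments visible in this hunk (meta_args, concrete_args, trace_act_ckpt, bias_addition_split); the import path, the "input" placeholder name, and the returned object are assumptions made for illustration.

import torch
import torch.nn as nn
from colossalai._analyzer.fx import symbolic_trace  # module path assumed

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
gm = symbolic_trace(                                 # assumed to return a GraphModule-like object
    model,
    meta_args={"input": torch.randn(4, 16)},         # wrapped into MetaTensor by wrap_fn above
    bias_addition_split=True,                        # use the split impls registered earlier
)
print(gm.graph)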


@@ -20,11 +20,10 @@ def _truncate_suffix(s: str):
import re
# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
return re.sub(r'_\d+$', '', s)
return re.sub(r"_\d+$", "", s)
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
def wrapper(impl):
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
getattr(ColoTracer, name)[func] = impl
@@ -34,7 +33,6 @@ def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custo
def register_leaf_module_impl(module: nn.Module):
def wrapper(impl):
ColoTracer._custom_leaf_module_impl[module] = impl
return impl
@@ -76,7 +74,7 @@ class ColoTracer(Tracer):
self.ckpt_regions = []
self.ckpt_idx = 0
self.mod_dir = ''
self.mod_dir = ""
# whether the tracer should split the bias_add ops into two ops
self.bias_addition_split = bias_addition_split
@@ -87,35 +85,41 @@ class ColoTracer(Tracer):
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False
# user can specify which modules are leaf modules and which are not
return (type(m) not in self._custom_non_leaf_module
and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
return type(m) not in self._custom_non_leaf_module and (
type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)
)
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
kwargs: Dict[str, Any]) -> Any:
def call_module(
self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Any:
curr_dir = self.mod_dir
self.mod_dir = 'self.' + self.path_of_module(m)
self.mod_dir = "self." + self.path_of_module(m)
rst = super().call_module(m, forward, args, kwargs)
self.mod_dir = curr_dir
return rst
def proxy(self, node: Node) -> 'ColoProxy':
def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self)
def create_proxy(self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
def create_proxy(
self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], "Proxy"] = None,
):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if kind == 'placeholder':
proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
_truncate_suffix(target), None)
elif kind == 'get_attr':
if kind == "placeholder":
proxy.meta_data = (
self.meta_args[target]
if target in self.meta_args
else self.concrete_args.get(_truncate_suffix(target), None)
)
elif kind == "get_attr":
self.disable_module_getattr = True
try:
attr_itr = self.root
@@ -125,20 +129,21 @@ class ColoTracer(Tracer):
proxy.meta_data = attr_itr
finally:
self.disable_module_getattr = False
elif kind == 'call_function':
elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method':
elif kind == "call_method":
self.disable_module_getattr = True
try:
if target == '__call__':
if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
)
finally:
self.disable_module_getattr = False
elif kind == 'call_module':
elif kind == "call_module":
mod = self.root.get_submodule(target)
self.disable_module_getattr = True
try:
@@ -158,11 +163,12 @@ class ColoTracer(Tracer):
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node
def trace(self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
def trace(
self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None,
) -> Graph:
if meta_args is None:
meta_args = {}
@@ -177,9 +183,7 @@ class ColoTracer(Tracer):
non_concrete_arg_names = sig_names - concrete_arg_names
# update concrete args with default values
for k, v in sig.parameters.items():
if k in sig_names - meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
def _check_arg_name_valid(names: Iterable[str]):
@@ -194,9 +198,9 @@ class ColoTracer(Tracer):
self.meta_args = meta_args
with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
self.mod_dir = 'self'
self.mod_dir = "self"
self.graph = super().trace(root, concrete_args=concrete_args)
self.mod_dir = ''
self.mod_dir = ""
self.graph.lint()
for node in self.graph.nodes:
@@ -266,17 +270,17 @@ class ColoTracer(Tracer):
# override the torch factory functions to create a proxy when the method
# is called during ``symbolic_trace()``.
def wrap_factory_method(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
isinstance(p, ColoProxy) for p in kwargs.values())
isinstance(p, ColoProxy) for p in kwargs.values()
)
if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.disable_module_getattr = True
try:
proxy = self.create_proxy('call_function', target, args, kwargs)
proxy = self.create_proxy("call_function", target, args, kwargs)
finally:
self.disable_module_getattr = False
return proxy
@@ -341,10 +345,13 @@ class ColoTracer(Tracer):
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
lambda node: ColoProxy(self, node, n, attr_val))
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (
None
if not self.param_shapes_constant
else lambda node: ColoProxy(self, node, n, attr_val)
)
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
@@ -355,8 +362,9 @@ class ColoTracer(Tracer):
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
maybe_parameter_proxy = maybe_get_proxy_for_attr(
attr_val, self.root.named_parameters(), parameter_proxy_cache
)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
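
A stripped-down sketch of the factory override pattern in the hunks above (simplified, not the actual ColoTracer code; is_proxy and make_node are hypothetical hooks standing in for the ColoProxy check and the create_proxy("call_function", ...) call): factory functions such as torch.ones are only recorded as graph nodes when at least one argument is already a proxy, e.g. a size that came from a traced input; otherwise the real factory function runs eagerly.

import functools

def wrap_factory_method(target, is_proxy, make_node):
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        if any(is_proxy(a) for a in args) or any(is_proxy(v) for v in kwargs.values()):
            return make_node(target, args, kwargs)   # record the call in the traced graph
        return target(*args, **kwargs)               # no proxy involved: plain eager call
    return wrapper

# while tracing, torch.ones would temporarily be replaced by
# wrap_factory_method(torch.ones, is_proxy, make_node) and restored afterwards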