[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format

Hongxin Liu, 2023-09-19 14:20:26 +08:00, committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Union
import torch
from packaging import version
@@ -24,25 +24,23 @@ orig_empty_like = torch.empty_like
def new(*args, **kwargs):
return orig_empty(*args, **kwargs, device=torch.device('meta'))
return orig_empty(*args, **kwargs, device=torch.device("meta"))
def new_strided(*args, **kwargs):
return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
return orig_empty_strided(*args, **kwargs, device=torch.device("meta"))
def new_like(*args, **kwargs):
return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
return orig_empty_like(*args, **kwargs, device=torch.device("meta"))
def register_meta(op, register_dispatcher=True):
def wrapper(f):
def add_func(op):
meta_table[op] = f
if register_dispatcher:
name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
@@ -54,7 +52,7 @@ def register_meta(op, register_dispatcher=True):
return wrapper
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
@@ -69,7 +67,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
output_padding: List[int],
groups: int,
):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
@@ -146,7 +143,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
kernel_size[i],
stride[i],
output_padding_list[i],
))
)
)
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
@@ -180,19 +178,39 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
*extra_args):
def meta__conv(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
*extra_args,
):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
padding, dilation, transposed, output_padding, groups, output_mask):
def meta_conv_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
return new_like(input), new_like(weight), new((bias_sizes))
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -224,7 +242,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
batch_sizes,
dropout_state,
):
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
@@ -240,8 +257,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
out_shape = ([mini_batch, seq_length, out_size *
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
out_shape = (
[mini_batch, seq_length, out_size * num_directions]
if batch_first
else [seq_length, mini_batch, out_size * num_directions]
)
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -257,15 +277,21 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
def meta_cudnn_rnn_backward(input: torch.Tensor,
weight: torch.Tensor,
weight_stride0: int,
hx: torch.Tensor,
cx: Optional[torch.Tensor] = None,
*args,
**kwargs):
return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
()) # (grad_input, grad_weight, grad_hx, grad_cx)
def meta_cudnn_rnn_backward(
input: torch.Tensor,
weight: torch.Tensor,
weight_stride0: int,
hx: torch.Tensor,
cx: Optional[torch.Tensor] = None,
*args,
**kwargs,
):
return (
new_like(input),
new_like(weight),
new_like(hx),
new_like(cx) if cx is not None else new(()),
) # (grad_input, grad_weight, grad_hx, grad_cx)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
@@ -278,7 +304,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.hardtanh_backward.default,
]
if version.parse(torch.__version__) < version.parse('2.0.0'):
if version.parse(torch.__version__) < version.parse("2.0.0"):
_unregistered_ewise += [
aten.prelu_backward.default,
]
@@ -296,37 +322,61 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
save_mean, save_invstd, train, eps, output_mask):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
def meta_bn_backward(
dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
train,
eps,
output_mask,
):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
return new_like(input), new((n_input)), new((n_input)), new(
(0), dtype=torch.uint8) # (output, running_mean, running_var, reserve)
return (
new_like(input),
new((n_input)),
new((n_input)),
new((0), dtype=torch.uint8),
) # (output, running_mean, running_var, reserve)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
save_mean, save_invstd, eps, reserve):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
def meta_cudnn_bn_backward(
dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
eps,
reserve,
):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs, n_input = input.size(0), input.size(1)
return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
grad_input_mask):
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
def meta_ln_backward(
dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
# ================================== Misc ==========================================
# Maybe incorrect
@@ -355,8 +405,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
def meta_embedding_dense_backward(
grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
):
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
# ============================== Dropout ===========================================
@@ -364,14 +415,14 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
# notice that mask is bool
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
return new_like(grad) # (grad_in)
return new_like(grad) # (grad_in)
if version.parse(torch.__version__) < version.parse('1.13.0'):
if version.parse(torch.__version__) < version.parse("1.13.0"):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.eye.m_out)
def meta_eye(n: int, m: int, out: torch.Tensor):
@@ -385,24 +436,28 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\
"tensors used as indices must be long, byte or bool tensors"
assert index.dtype in [
torch.long,
torch.int8,
torch.bool,
], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
assert index.shape[j] == self.shape[
k +
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
assert (
index.shape[j] == self.shape[k + j]
), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
assert len(
indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
assert (
len(indices) <= self.ndim
), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs
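
For reference, the file above maintains a table of meta (shape-only) kernels for aten ops so traced tensors never need real storage. A minimal sketch of the registration pattern, assuming only PyTorch >= 1.12; the real decorator in this file additionally registers the kernel with torch.library through meta_lib.impl:

import torch

aten = torch.ops.aten
meta_table = {}

def register_meta(op):
    # simplified stand-in for the decorator above: remember a shape-only
    # implementation for an aten overload
    def wrapper(f):
        meta_table[op] = f
        return f
    return wrapper

@register_meta(aten.relu.default)
def meta_relu(x: torch.Tensor):
    # a meta kernel only has to produce an output with the right shape and
    # dtype on the meta device; no real computation or allocation happens
    return torch.empty_like(x, device="meta")

x = torch.empty(4, 4, device="meta")
print(meta_table[aten.relu.default](x).shape)  # torch.Size([4, 4])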

View File

@@ -1,5 +1,4 @@
import torch
import torch.distributed as dist
from packaging import version
__all__ = [
@@ -48,7 +47,7 @@ _DistCommMethod = [
"scatter",
]
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
aten = torch.ops.aten
# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp

View File

@@ -8,7 +8,7 @@ from contextlib import contextmanager
from enum import Enum, auto
from functools import partial, reduce
from numbers import Number
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Union
import torch
from packaging import version
@@ -36,15 +36,15 @@ def _format_flops(flop):
B = 1e9
T = 1e12
if flop < K:
return f'{flop:.2f}'
return f"{flop:.2f}"
elif flop < M:
return f'{flop / K:.2f}K'
return f"{flop / K:.2f}K"
elif flop < B:
return f'{flop / M:.2f}M'
return f"{flop / M:.2f}M"
elif flop < T:
return f'{flop / B:.2f}B'
return f"{flop / B:.2f}B"
else:
return f'{flop / T:.2f}T'
return f"{flop / T:.2f}T"
def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
Returns:
Number: The total number of floating point operations (FWD + BWD).
"""
maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
maybe_inplace = (
getattr(module, "inplace", False)
or kwargs.get("inplace", False)
or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_")
)
class DummyModule(torch.nn.Module):
def __init__(self, func):
super().__init__()
self.func = func
@@ -74,21 +76,20 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
flop_counts = defaultdict(lambda: defaultdict(int))
parents = ['Global']
parents = ["Global"]
module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
class FlopTensor(MetaTensor):
_tensor: torch.Tensor
def __repr__(self):
name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# no_dispatch is only needed if you use enable_python_mode.
# It prevents infinite recursion.
rs = super().__torch_dispatch__(func, types, args, kwargs)
@@ -115,9 +116,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return isinstance(x, torch.Tensor) and x.is_floating_point()
def create_backwards_push(name):
class PushState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -134,9 +133,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return PushState.apply
def create_backwards_pop(name):
class PopState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -147,14 +144,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
@staticmethod
def backward(ctx, *grad_outs):
nonlocal parents
assert (parents[-1] == name)
assert parents[-1] == name
parents.pop()
return grad_outs
return PopState.apply
def enter_module(name):
def f(module, inputs):
nonlocal parents
parents.append(name)
@@ -165,10 +161,9 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return f
def exit_module(name):
def f(module, inputs, outputs):
nonlocal parents
assert (parents[-1] == name)
assert parents[-1] == name
parents.pop()
outputs = normalize_tuple(outputs)
return create_backwards_push(name)(*outputs)
@@ -189,7 +184,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
for mod in flop_counts.keys():
print(f"Module: ", mod)
for k, v in flop_counts[mod].items():
print('\t', k, _format_flops(v))
print("\t", k, _format_flops(v))
print()
def detach_variables(r):
@@ -201,7 +196,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
def wrap(r):
if isinstance(r, torch.Tensor):
data_ptr_fn = getattr(r, '_tensor', r).data_ptr
data_ptr_fn = getattr(r, "_tensor", r).data_ptr
r = FlopTensor(detach_variables(r))
if maybe_inplace:
r = r + 0
@@ -375,8 +370,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
'shape') else inputs[affine_arg_index]
has_affine = (
inputs[affine_arg_index].shape is not None
if hasattr(inputs[affine_arg_index], "shape")
else inputs[affine_arg_index]
)
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
@@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
@@ -420,33 +418,30 @@ def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Calla
def zero_flop_jit(*args):
"""
Count flops for zero flop layers.
Count flops for zero flop layers.
"""
return 0
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
flop_mapping = {
# gemm
# gemm
aten.mm.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
# convolution
# convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
# normalization
# normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
# pooling
# pooling
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
@@ -469,7 +464,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
}
ewise_flop_aten = [
# basic op
# basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
@@ -485,8 +480,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
# activation op
# activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
@@ -509,15 +503,12 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
# dropout
# dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
# distribution
# distribution
aten.bernoulli_.float,
# where
# where
aten.where.self,
]
for op in ewise_flop_aten:
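
For reference, flop_count in the file above wraps the inputs in FlopTensor, runs the module (or callable), and sums forward and backward FLOPs per aten op through flop_mapping. A hedged usage sketch; the import path is assumed from the repository layout and may differ:

import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_count  # assumed path

model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.GELU())
x = torch.randn(32, 128, requires_grad=True)

# returns total FWD + BWD floating point operations; verbose=True prints a
# per-module breakdown formatted by _format_flops (K/M/B/T suffixes)
total = flop_count(model, x, verbose=True)
print(total)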

View File

@@ -3,12 +3,12 @@ from functools import partial
import torch
import torch.distributed as dist
from torch.types import _bool, _device, _dtype
from torch.utils._pytree import tree_flatten, tree_map
from torch.types import _device
from torch.utils._pytree import tree_map
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
__all__ = ['MetaTensor', 'MetaTensorMode']
__all__ = ["MetaTensor", "MetaTensorMode"]
def register_storage(r, data_ptr_fn=None):
@@ -28,8 +28,7 @@ def _normalize_tuple(x):
# a hack of inplace execution in PyTorch
def _assert_alias(func):
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive
)
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen) # TODO: check if should be this aggressive
class MetaTensor(torch.Tensor):
@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
requires_grad=requires_grad) # deceive the frontend for aten selections
device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
requires_grad=requires_grad,
) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
val = elem.data_ptr()
data_ptr_fn = lambda: val
r._tensor = r._tensor.to(torch.device('meta'))
r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta`
register_storage(r._tensor, data_ptr_fn)
@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor):
return r
def __repr__(self):
name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor):
x = x._tensor
elif isinstance(x, torch.Tensor):
device = x.device
x = x.to(torch.device('meta'))
x = x.to(torch.device("meta"))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
if 'device' in kwargs:
device = kwargs['device']
kwargs['device'] = torch.device('meta')
if "device" in kwargs:
device = kwargs["device"]
kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta
# here we detect whether or not the execution generates a physical copy
@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor):
nonlocal device
if isinstance(x, str) or isinstance(x, _device):
device = x
return torch.device('meta')
return torch.device("meta")
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, device=device)
def cpu(self, *args, **kwargs):
if self.device.type == 'cpu':
if self.device.type == "cpu":
return self.to(*args, **kwargs)
return self.to(*args, device='cpu', **kwargs)
return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
return self.to(device='cuda:0', non_blocking=non_blocking)
return self.to(device="cuda:0", non_blocking=non_blocking)
def data_ptr(self):
return self._tensor.data_ptr()
@@ -177,19 +177,17 @@ class MetaTensorMode(object):
"""
def __init__(self):
self.torch_overrides = {} # override torch.xxx
self.dist_overrides = {} # override torch.distributed.xxx
self.torch_overrides = {} # override torch.xxx
self.dist_overrides = {} # override torch.distributed.xxx
def __enter__(self):
def _dummy(*args, **kwargs):
pass
def _new(*args, orig_new=torch.empty, **kwargs):
return MetaTensor(orig_new(*args, **{
**kwargs, 'device': 'meta'
}),
device=kwargs.get('device', torch.device('cpu')))
return MetaTensor(
orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
)
for func in _TorchOverrideableFactoryMethod:
self.torch_overrides[func] = getattr(torch, func)
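
For reference, MetaTensor and MetaTensorMode above let a model be materialized on the meta device while reporting a "pretend" device to the frontend. A small sketch, assuming the classes are importable from the _subclasses package as this file's __all__ suggests:

import torch
from colossalai._analyzer._subclasses import MetaTensorMode  # assumed import path

with MetaTensorMode():
    # factory functions (torch.empty, torch.randn, ...) are patched inside the
    # context, so parameters and activations become MetaTensor objects without
    # allocating real memory
    model = torch.nn.Linear(1024, 1024)
    x = torch.randn(16, 1024)
    y = model(x)

print(type(y).__name__, y.shape)  # MetaTensor (16, 1024)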

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple
from typing import Any, Dict, List, Tuple
import torch
@@ -22,7 +22,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
import colossalai
from colossalai.fx._compatibility import compatibility
_register_custom_builtin('colossalai', 'import colossalai', colossalai)
_register_custom_builtin("colossalai", "import colossalai", colossalai)
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
@@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
"""
Generate the checkpoint function call code text
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
outputs = ", ".join(output_vars)
inputs = ", ".join(input_vars)
return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
"""
Check if the node could end the ckpt region at `ckpt_level`
"""
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
return node.meta['info'].activation_checkpoint[ckpt_level] is not None
if len(node.meta["info"].activation_checkpoint) > ckpt_level:
return node.meta["info"].activation_checkpoint[ckpt_level] is not None
return True
@@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
current_region = None
for idx, node in enumerate(node_list):
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
if len(node.meta["info"].activation_checkpoint) > ckpt_level:
act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
return ckpt_regions
def emit_ckpt_func(body,
ckpt_func,
node_list: List[Node],
emit_node_func,
delete_unused_value_func,
ckpt_level=0,
in_ckpt=False):
def emit_ckpt_func(
body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False
):
"""Emit ckpt function in nested way
Args:
@@ -156,12 +152,12 @@ def emit_ckpt_func(body,
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
# the label will be '0_1_1'
label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
ckpt_func.append(f'{ckpt_fn_def}\n')
ckpt_func.append(f"{ckpt_fn_def}\n")
# if there is more level to fetch
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)):
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@@ -174,33 +170,40 @@ def emit_ckpt_func(body,
break
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
ckpt_level + 1, True)
ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(
ckpt_func,
ckpt_func_buffer,
ckpt_node_list,
emit_node_func,
delete_unused_value_func,
ckpt_level + 1,
True,
)
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n"
if in_ckpt:
usage = ' ' + usage
usage = " " + usage
body.append(usage)
@@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# process ckpt_regions
if node_idx in start_idx:
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
@@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
@compatibility(is_backward_compatible=True)
class ActivationCheckpointCodeGen(CodeGen):
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
@@ -251,7 +253,7 @@ class ActivationCheckpointCodeGen(CodeGen):
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = ['']
maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -259,7 +261,7 @@ class ActivationCheckpointCodeGen(CodeGen):
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -281,16 +283,16 @@ class ActivationCheckpointCodeGen(CodeGen):
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return '()'
return "()"
typename = _type_repr(o)
if hasattr(o, '__origin__'):
if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, '__args__'):
if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
@@ -309,19 +311,18 @@ class ActivationCheckpointCodeGen(CodeGen):
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, '_fields'):
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
args_s = ', '.join(_get_repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
args_s = ", ".join(_get_repr(a) for a in args)
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}'
return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
@@ -347,82 +348,94 @@ class ActivationCheckpointCodeGen(CodeGen):
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if user.op == 'placeholder':
if user.op == "placeholder":
return
if user.op == 'output':
body.append('\n')
if user.op == "output":
body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
body.append(f"; {to_delete_str}\n")
else:
body.append('\n')
body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
raw_name = node.target.replace('*', '')
maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
body.append(f'{repr(node)} = {raw_name}\n')
body.append(f"{repr(node)} = {raw_name}\n")
return
elif node.op == 'call_method':
elif node.op == "call_method":
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
f'({_format_args(node.args[1:], node.kwargs)})')
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})"
)
return
elif node.op == 'call_function':
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
)
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
body.append(
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
)
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if global_name == 'getattr' and \
isinstance(node.args, tuple) and \
isinstance(node.args[1], str) and \
node.args[1].isidentifier() and \
len(node.args) == 2:
if (
global_name == "getattr"
and isinstance(node.args, tuple)
and isinstance(node.args[1], str)
and node.args[1].isidentifier()
and len(node.args) == 2
):
body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
)
return
body.append(
f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
if node.meta.get('is_wrapped', False):
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
)
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == 'call_module':
elif node.op == "call_module":
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
)
return
elif node.op == 'get_attr':
elif node.op == "get_attr":
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
elif node.op == 'output':
elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
raise NotImplementedError(f'node: {node.op} {node.target}')
raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
@@ -432,13 +445,13 @@ class ActivationCheckpointCodeGen(CodeGen):
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body.append('pass\n')
body.append("pass\n")
if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
wrap_name = add_global("wrap", torch.fx.wrap)
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
wrap_stmts = ''
wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
@@ -447,11 +460,11 @@ class ActivationCheckpointCodeGen(CodeGen):
add_global(name, value)
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = ''.join(ckpt_func) + prologue
prologue = "".join(ckpt_func) + prologue
prologue = prologue
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
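
To picture the output of the codegen above: ActivationCheckpointCodeGen moves the nodes of each checkpoint region into a checkpoint_<label> method and replaces them in forward with the call produced by _gen_ckpt_usage. A hand-written illustration of the shape of the generated module (not actual codegen output):

import torch
import torch.utils.checkpoint

class TracedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 8)
        self.linear2 = torch.nn.Linear(8, 8)

    # body of activation-checkpoint region "0", emitted by emit_ckpt_func
    def checkpoint_0(self, x):
        linear1 = self.linear1(x)
        relu = torch.relu(linear1)
        return relu

    def forward(self, x):
        # call site formatted by _gen_ckpt_usage (use_reentrant=False)
        relu = torch.utils.checkpoint.checkpoint(self.checkpoint_0, x, use_reentrant=False)
        linear2 = self.linear2(relu)
        return linear2

m = TracedModule()
print(m(torch.randn(2, 8, requires_grad=True)).shape)  # torch.Size([2, 8])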

View File

@@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode
try:
from torch.fx.graph import _PyTreeCodeGen
SUPPORT_PT_CODEGEN = True
except ImportError:
SUPPORT_PT_CODEGEN = False
@@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent
# This is a copy of torch.fx.graph_module._WrappedCall.
# It should be removed when we stop supporting torch < 1.12.0.
class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
self.cls_call = cls_call
@@ -50,12 +50,14 @@ class _WrappedCall:
# constituent substrings of the error message
tb_repr = traceback.format_exc()
custom_msg = ("Call using an FX-traced Module, "
f"line {err_lineno} of the traced Module's "
"generated forward function:")
before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
custom_msg = (
"Call using an FX-traced Module, "
f"line {err_lineno} of the traced Module's "
"generated forward function:"
)
before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
marker = "~" * err_line_len + "~~~ <--- HERE"
err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
# joined message
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
@@ -65,11 +67,14 @@ class _WrappedCall:
if self.cls_call is not None:
return self.cls_call(obj, *args, **kwargs)
else:
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e:
assert e.__traceback__
topmost_framesummary: traceback.FrameSummary = \
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
traceback.walk_tb(e.__traceback__)
)[
-1
] # type: ignore[arg-type]
if "eval_with_key" in topmost_framesummary.filename:
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
raise e.with_traceback(None)
@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
code.
"""
def __init__(self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: torch.fx.Graph,
class_name: str = 'GraphModule'):
def __init__(
self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
):
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):
@@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule):
if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module='self')
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To split ckpt functions code and forward code
@@ -157,8 +161,8 @@ class ColoGraphModule(torch.fx.GraphModule):
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
if '_wrapped_call' not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
if "_wrapped_call" not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
@@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule):
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
torch.save(self.state_dict(), folder / 'state_dict.pt')
torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
@@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module):
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
module_file = folder / f'{module_name}.pt'
module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
@@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module):
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
module_file = folder / 'module.py'
module_file = folder / "module.py"
module_file.write_text(model_str)
init_file = folder / '__init__.py'
init_file.write_text('from .module import *')
init_file = folder / "__init__.py"
init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
warnings.warn("Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}")
warnings.warn(
"Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}"
)

View File

@@ -1,9 +1,9 @@
from dataclasses import dataclass, field
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.autograd.profiler_util import _format_memory, _format_time
from torch.fx import Graph, GraphModule, Node
from torch.autograd.profiler_util import _format_memory
from torch.fx import Node
from colossalai._analyzer.envs import MeshConfig
@@ -85,12 +85,12 @@ class MetaInfo:
node: Node
# directory
mod_dir: str = ''
mod_dir: str = ""
# ctx[data_ptr] = Tensor
# mark the storage for ctx.save_for_backward
global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
# should be updated after each graph manipulation
# ============================== Update ====================================
@@ -100,7 +100,7 @@ class MetaInfo:
inputs: Tuple[torch.Tensor] = ()
outputs: Tuple[torch.Tensor] = ()
is_alias: Tuple[bool] = () # whether the output is an alias of input
is_alias: Tuple[bool] = () # whether the output is an alias of input
# compute cost
fwd_flop: Optional[int] = 0
@@ -112,29 +112,29 @@ class MetaInfo:
# should keep the same whenever manipulated
# ============================= Invariant ==================================
activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
to_offload: Optional[bool] = False
sharding_spec: str = 'RR'
sharding_spec: str = "RR"
def __new__(cls, node: Node, **kwargs):
orig_init = cls.__init__
# if initialized, return the existing one
# should disable the __init__ function
if node.meta.get('info', None) is not None:
if node.meta.get("info", None) is not None:
def _dummy(self, *args, **kwargs):
if getattr(self, '_is_init', False):
if getattr(self, "_is_init", False):
self._is_init = True
orig_init(self, *args, **kwargs)
cls.__init__ = orig_init
cls.__init__ = _dummy
return node.meta['info']
return node.meta["info"]
return super().__new__(cls)
def __post_init__(self):
self.node.meta['info'] = self
self.node.meta["info"] = self
@property
def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
@@ -188,24 +188,26 @@ class MetaInfo:
return compute_size_in_bytes(self.inputs)
def __repr__(self):
s = f'Node {self.node.name}'
s = f"Node {self.node.name}"
if self.parameters:
s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
s += f"\n\thas parameter of size {_format_memory(self.param_size)}"
if self.buffers:
s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}"
if self.output_size:
s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
s += f"\n\thas output activation of size {_format_memory(self.output_size)}"
# if self.total_size:
# s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
if self.temp_size:
s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}"
if self.backward_size:
s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
s += f'\n\tfwd_flop = {self.fwd_flop}'\
f'\n\tbwd_flop = {self.bwd_flop}'\
f'\n\tfwd_comm = {self.fwd_comm}'\
f'\n\tbwd_comm = {self.bwd_comm}'\
f'\n\tto_recompute = {self.to_recompute}'\
f'\n\tto_offload = {self.to_offload}'\
f'\n\tsharding_spec = {self.sharding_spec}'
s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}"
s += (
f"\n\tfwd_flop = {self.fwd_flop}"
f"\n\tbwd_flop = {self.bwd_flop}"
f"\n\tfwd_comm = {self.fwd_comm}"
f"\n\tbwd_comm = {self.bwd_comm}"
f"\n\tto_recompute = {self.to_recompute}"
f"\n\tto_offload = {self.to_offload}"
f"\n\tsharding_spec = {self.sharding_spec}"
)
return s
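
For reference, MetaInfo above is the per-node record that the analyzer passes read and write; constructing it binds it to node.meta["info"], and constructing it again returns the same record. A small sketch; the import path matches the one visible in the profiler diff below:

import torch
import torch.fx
from colossalai._analyzer.fx.node_util import MetaInfo

gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
node = next(iter(gm.graph.nodes))

info = MetaInfo(node)          # __post_init__ stores it as node.meta["info"]
assert node.meta["info"] is info
assert MetaInfo(node) is info  # __new__ hands back the existing record
info.fwd_flop = 2 * 4 * 4      # the profiling passes fill these fields in later
print(info)                    # "Node <name>" plus the flop/comm/sharding summary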

View File

@@ -1,8 +1,8 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch
import torch.fx
from torch.autograd.profiler_util import _format_memory, _format_time
from torch.autograd.profiler_util import _format_memory
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
@@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo
def _format_flops(flops: float) -> str:
"""Returns a formatted FLOP size string"""
if flops > 1e12:
return f'{flops / 1e12:.2f} TFLOPs'
return f"{flops / 1e12:.2f} TFLOPs"
elif flops > 1e9:
return f'{flops / 1e9:.2f} GFLOPs'
return f"{flops / 1e9:.2f} GFLOPs"
elif flops > 1e6:
return f'{flops / 1e6:.2f} MFLOPs'
return f"{flops / 1e6:.2f} MFLOPs"
elif flops > 1e3:
return f'{flops / 1e3:.2f} kFLOPs'
return f'{flops} FLOPs'
return f"{flops / 1e3:.2f} kFLOPs"
return f"{flops} FLOPs"
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
Fetch shape argument from ``ShapeProp`` without re-executing
the ``GraphModule`` from scratch.
"""
_profileable = [
'call_function',
'call_module',
'call_method',
"call_function",
"call_module",
"call_method",
]
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
@@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter):
self.args_iter: Iterator[Any] = iter(args)
for node in self.module.graph.nodes:
self.run_node(node) # No need to store.
self.run_node(node) # No need to store.
if self.garbage_collect_values:
for to_delete in self.user_to_last_uses.get(node, []):
del self.env[to_delete]
if node.op == 'output':
if node.op == "output":
output_val = self.env[node]
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
@@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
print(
"`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
@@ -145,36 +147,38 @@ class GraphProfiler(torch.fx.Interpreter):
node: Node
n_info = MetaInfo(node)
last_n_info = last_n_info or n_info
node_summaries.append([
node.op,
str(node),
_format_memory(n_info.accumulate_size),
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
_format_memory(n_info.output_size),
_format_memory(n_info.temp_size),
_format_memory(n_info.param_size),
_format_memory(n_info.backward_size),
_format_flops(n_info.fwd_flop),
_format_flops(n_info.bwd_flop),
])
node_summaries.append(
[
node.op,
str(node),
_format_memory(n_info.accumulate_size),
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
_format_memory(n_info.output_size),
_format_memory(n_info.temp_size),
_format_memory(n_info.param_size),
_format_memory(n_info.backward_size),
_format_flops(n_info.fwd_flop),
_format_flops(n_info.bwd_flop),
]
)
last_n_info = n_info
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Accumulate size',
'Incremental size',
'Output size',
'Temp size',
'Param size',
'Backward size',
'Fwd FLOPs',
'Bwd FLOPs',
"Op type",
"Op",
"Accumulate size",
"Incremental size",
"Output size",
"Temp size",
"Param size",
"Backward size",
"Fwd FLOPs",
"Bwd FLOPs",
]
return tabulate(node_summaries, headers=headers, stralign='right')
return tabulate(node_summaries, headers=headers, stralign="right")
class CommunicationProfiler(GraphProfiler):
@@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler):
>>> def my_fn_flop_count_impl(*args, **kwargs):
>>> return 0, 0
"""
_custom_flop_count_impl = {}
def run_node(self, n: torch.fx.Node) -> Any:
@@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler):
(
n_info.fwd_flop,
n_info.bwd_flop,
) = getattr(self, n.op)(n.target, args, kwargs)
) = getattr(
self, n.op
)(n.target, args, kwargs)
except Exception as e:
raise RuntimeError(
f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
f"Please refer to function's docstring to register the relevant profile_impl for this node!"
) from e
# retain the autograd graph
@@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler):
return _denormalize_tuple(n_info.outputs)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the profiling result.
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
@@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler):
else:
return flop_count(target, *args, **kwargs)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the profiling result.
@@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler):
assert isinstance(target, str)
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node and return the profiling result.
@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
Returns:
GraphModule: The same GraphModule with profiling information
"""
for profiler_cls in (FlopProfiler,
# CommunicationProfiler, # TODO: add communication profiling
):
for profiler_cls in (
FlopProfiler,
# CommunicationProfiler, # TODO: add communication profiling
):
profiler = profiler_cls(module)
profiler.propagate(*args, device=_current_device(module))
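
For reference, graph_profile_pass above instantiates FlopProfiler (communication profiling is still a TODO in that loop) and propagates through the GraphModule, storing fwd/bwd FLOPs in each node's MetaInfo. A hedged end-to-end sketch; the pass names come from the __init__ diff further down, but the exact import path and argument convention (real vs. meta tensors) are assumptions:

import torch
import torch.fx
from colossalai._analyzer.fx.passes import graph_profile_pass, shape_prop_pass  # assumed path

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
gm = torch.fx.symbolic_trace(model)
x = torch.randn(8, 64)

shape_prop_pass(gm, x)                        # record shapes/outputs in MetaInfo first
gm = graph_profile_pass(gm, x, verbose=True)  # then fill in fwd_flop / bwd_flop per node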

View File

@@ -54,7 +54,7 @@ def _current_device(module):
try:
return next(module.parameters()).device
except StopIteration:
return torch.device('cpu')
return torch.device("cpu")
@compatibility(is_backward_compatible=False)
@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
>>> # do something here
>>> return torch.empty(output_shape, device=output_device)
"""
_custom_dispatch_func = {}
_mode = MetaTensorMode()
@@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter):
r = getattr(self, n.op)(n.target, args, kwargs)
def unwrap_fn(elem):
def _convert_meta(t: torch.Tensor):
if t.device == 'meta':
if t.device == "meta":
return t
else:
return t.to('meta')
return t.to("meta")
if isinstance(elem, MetaTensor):
if getattr(self, '_is_param', False):
if getattr(self, "_is_param", False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor)
@@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter):
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
if n.op == 'call_module':
if n.op == "call_module":
submod = self.fetch_attr(n.target)
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
else:
n_info.parameters.update({
k.name: MetaTensor(v)
for k, v in zip(n.args, args)
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
})
n_info.parameters.update(
{
k.name: MetaTensor(v)
for k, v in zip(n.args, args)
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
}
)
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
tuple(v for v in kwargs.values() if is_pure_tensor(v))
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
v for v in kwargs.values() if is_pure_tensor(v)
)
# align with SPMD
if isinstance(r, (tuple, list)):
@@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter):
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
return r
def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the result.
If the target of ``Node`` is registered with ``@register_shape_impl``,
@@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter):
else:
return res
def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
@@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter):
convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
args[0], torch.nn.parameter.Parameter):
args[0], torch.nn.parameter.Parameter
):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)
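Concretely, the dispatch hook referenced in the ShapeProp docstring above can be fed like this; the op choice and the shape arithmetic are illustrative, and the import path is an assumption (the registry decorator itself appears in the next file of this diff):

# Hedged sketch of a custom shape rule, mirroring the docstring example above.
import torch
import torch.nn.functional as F

from colossalai._analyzer.fx import register_shape_impl  # assumed import path

@register_shape_impl(F.linear)
def linear_shape_impl(input, weight, bias=None):
    # output keeps the leading dims of `input` and takes `weight`'s out_features
    output_shape = (*input.shape[:-1], weight.shape[0])
    return torch.empty(output_shape, device=input.device)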

View File

@@ -1,5 +1,3 @@
import torch
import torch.fx
from torch.fx import GraphModule
from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
@@ -7,7 +5,6 @@ from .passes.graph_profile import FlopProfiler
def register_flop_count_impl(func):
def wrapper(impl):
FlopProfiler._custom_flop_count_impl[func] = impl
return impl
@@ -16,7 +13,6 @@ def register_flop_count_impl(func):
def register_shape_impl(func):
def wrapper(impl):
ShapeProp._custom_dispatch_func[func] = impl
return impl
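For symmetry, a hedged sketch of the companion FLOP-count registry; the decorator's bookkeeping is exactly what the hunk above shows, but the return convention of a registered impl (a single FLOP number here) is an assumption made purely for illustration:

import torch

from colossalai._analyzer.fx import register_flop_count_impl  # assumed import path

@register_flop_count_impl(torch.bmm)
def bmm_flop_count(*args, **kwargs):
    a, b = args[0], args[1]
    # batched matmul A(B, M, K) @ B(B, K, N): count one multiply-add as 2 FLOPs
    return 2 * a.shape[0] * a.shape[1] * a.shape[2] * b.shape[-1]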

View File

@@ -12,7 +12,7 @@ from .tracer import register_tracer_impl
__all__ = []
@register_tracer_impl(F.linear, name='_bias_addition_impl')
@register_tracer_impl(F.linear, name="_bias_addition_impl")
def linear_impl(input, weight, bias=None):
if bias is None:
return F.linear(input, weight)
@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None):
return F.linear(input, weight) + bias
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
@register_tracer_impl(F.conv1d, name="_bias_addition_impl")
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1))
(-1, 1)
)
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
@register_tracer_impl(F.conv2d, name="_bias_addition_impl")
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1))
(-1, 1, 1)
)
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
@register_tracer_impl(F.conv3d, name="_bias_addition_impl")
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1, 1))
(-1, 1, 1, 1)
)
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl")
def conv_transpose1d_impl(
input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1),
):
if bias is None:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose1d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
else:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))
return F.conv_transpose1d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input,
weight,
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl")
def conv_transpose2d_impl(
input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)
):
if bias is None:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose2d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
else:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))
return F.conv_transpose2d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl")
def conv_transpose3d_impl(
input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1),
):
if bias is None:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose3d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
else:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
return F.conv_transpose3d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.addmm, name="_bias_addition_impl")
@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl")
def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
return F.linear(mat1, mat2.transpose(0, 1)) + input
@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
@register_tracer_impl(torch.addbmm, name="_bias_addition_impl")
@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl")
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
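The rewrites above are purely algebraic, so their correctness is easy to sanity-check outside the tracer; the snippet below is illustrative and only uses plain torch.nn.functional calls:

import torch
import torch.nn.functional as F

x, w, b = torch.randn(2, 8), torch.randn(4, 8), torch.randn(4)
assert torch.allclose(F.linear(x, w, b), F.linear(x, w) + b)  # what linear_impl traces instead

y, k, kb = torch.randn(2, 3, 16, 16), torch.randn(5, 3, 3, 3), torch.randn(5)
# conv2d_impl's reshape((-1, 1, 1)) broadcasts the bias over H and W
assert torch.allclose(F.conv2d(y, k, kb), F.conv2d(y, k) + kb.reshape((-1, 1, 1)))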

View File

@@ -4,6 +4,7 @@ from .tracer import register_leaf_module, register_leaf_module_impl
try:
import apex
register_leaf_module(apex.normalization.FusedLayerNorm)
register_leaf_module(apex.normalization.FusedRMSNorm)
register_leaf_module(apex.normalization.MixedFusedLayerNorm)
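The registrations above only run when apex is importable; the same hook can mark any custom module as a tracing leaf. A hedged sketch with a made-up module (the import path is assumed, and MyFusedNorm exists only for illustration):

import torch.nn as nn

from colossalai._analyzer.fx import register_leaf_module  # assumed import path

class MyFusedNorm(nn.Module):  # hypothetical fused-op wrapper
    def forward(self, x):
        return x

register_leaf_module(MyFusedNorm)  # traced as a single call_module node, like the apex norms above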

View File

@@ -1,10 +1,8 @@
import operator
from typing import Any, Callable, Dict, Optional, Set, Union
from typing import Any, Callable, Dict, Optional, Union
import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.fx import Node, Proxy
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
@@ -32,7 +30,7 @@ class ColoProxy(Proxy):
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch:
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl
return proxy
@@ -72,7 +70,7 @@ class ColoProxy(Proxy):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
@@ -89,7 +87,6 @@ class ColoProxy(Proxy):
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
@@ -102,11 +99,11 @@ class ColoAttribute(ColoProxy):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.fx import Tracer
@@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor
try:
from ..codegen import ActivationCheckpointCodeGen
SUPPORT_ACTIVATION = True
except:
SUPPORT_ACTIVATION = False
@@ -16,7 +17,7 @@ from .tracer import ColoTracer
def _default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _current_device(module: torch.nn.Module):
@@ -144,10 +145,9 @@ def symbolic_trace(
if meta_args:
device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
bias_addition_split=bias_addition_split).trace(root.to(device),
concrete_args=concrete_args,
meta_args=tree_map(wrap_fn, meta_args))
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
)
if trace_act_ckpt and SUPPORT_ACTIVATION:
graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device)
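The wrap_fn above converts the user-supplied sample tensors into MetaTensor on the tracing device before tracing, then the model is moved back afterwards. A small hedged sketch of the same wrapping done by hand (the argument names and shapes are illustrative):

import torch
from colossalai._analyzer._subclasses import MetaTensor  # as imported by this module

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem

meta_args = {"input_ids": torch.zeros(1, 128, dtype=torch.long)}  # illustrative keys/shapes
meta_args = {k: wrap_fn(v) for k, v in meta_args.items()}  # what tree_map(wrap_fn, meta_args) produces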

View File

@@ -20,11 +20,10 @@ def _truncate_suffix(s: str):
import re
# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
return re.sub(r'_\d+$', '', s)
return re.sub(r"_\d+$", "", s)
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
def wrapper(impl):
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
getattr(ColoTracer, name)[func] = impl
@@ -34,7 +33,6 @@ def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custo
def register_leaf_module_impl(module: nn.Module):
def wrapper(impl):
ColoTracer._custom_leaf_module_impl[module] = impl
return impl
@@ -76,7 +74,7 @@ class ColoTracer(Tracer):
self.ckpt_regions = []
self.ckpt_idx = 0
self.mod_dir = ''
self.mod_dir = ""
# whether the tracer should split the bias_add ops into two ops
self.bias_addition_split = bias_addition_split
@@ -87,35 +85,41 @@ class ColoTracer(Tracer):
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False
# user can specify which modules are leaf modules and which are not
return (type(m) not in self._custom_non_leaf_module
and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
return type(m) not in self._custom_non_leaf_module and (
type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)
)
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
kwargs: Dict[str, Any]) -> Any:
def call_module(
self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Any:
curr_dir = self.mod_dir
self.mod_dir = 'self.' + self.path_of_module(m)
self.mod_dir = "self." + self.path_of_module(m)
rst = super().call_module(m, forward, args, kwargs)
self.mod_dir = curr_dir
return rst
def proxy(self, node: Node) -> 'ColoProxy':
def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self)
def create_proxy(self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
def create_proxy(
self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], "Proxy"] = None,
):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if kind == 'placeholder':
proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
_truncate_suffix(target), None)
elif kind == 'get_attr':
if kind == "placeholder":
proxy.meta_data = (
self.meta_args[target]
if target in self.meta_args
else self.concrete_args.get(_truncate_suffix(target), None)
)
elif kind == "get_attr":
self.disable_module_getattr = True
try:
attr_itr = self.root
@@ -125,20 +129,21 @@ class ColoTracer(Tracer):
proxy.meta_data = attr_itr
finally:
self.disable_module_getattr = False
elif kind == 'call_function':
elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method':
elif kind == "call_method":
self.disable_module_getattr = True
try:
if target == '__call__':
if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
)
finally:
self.disable_module_getattr = False
elif kind == 'call_module':
elif kind == "call_module":
mod = self.root.get_submodule(target)
self.disable_module_getattr = True
try:
@@ -158,11 +163,12 @@ class ColoTracer(Tracer):
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node
def trace(self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
def trace(
self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None,
) -> Graph:
if meta_args is None:
meta_args = {}
@@ -177,9 +183,7 @@ class ColoTracer(Tracer):
non_concrete_arg_names = sig_names - concrete_arg_names
# update concrete args with default values
for k, v in sig.parameters.items():
if k in sig_names - meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
def _check_arg_name_valid(names: Iterable[str]):
@@ -194,9 +198,9 @@ class ColoTracer(Tracer):
self.meta_args = meta_args
with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
self.mod_dir = 'self'
self.mod_dir = "self"
self.graph = super().trace(root, concrete_args=concrete_args)
self.mod_dir = ''
self.mod_dir = ""
self.graph.lint()
for node in self.graph.nodes:
@@ -266,17 +270,17 @@ class ColoTracer(Tracer):
# override the torch factory functions to create a proxy when the method
# is called during ``symbolic_trace()``.
def wrap_factory_method(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
isinstance(p, ColoProxy) for p in kwargs.values())
isinstance(p, ColoProxy) for p in kwargs.values()
)
if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.disable_module_getattr = True
try:
proxy = self.create_proxy('call_function', target, args, kwargs)
proxy = self.create_proxy("call_function", target, args, kwargs)
finally:
self.disable_module_getattr = False
return proxy
@@ -341,10 +345,13 @@ class ColoTracer(Tracer):
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
lambda node: ColoProxy(self, node, n, attr_val))
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (
None
if not self.param_shapes_constant
else lambda node: ColoProxy(self, node, n, attr_val)
)
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
@@ -355,8 +362,9 @@ class ColoTracer(Tracer):
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
maybe_parameter_proxy = maybe_get_proxy_for_attr(
attr_val, self.root.named_parameters(), parameter_proxy_cache
)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy