Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
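Editor's note: the diff below is almost entirely mechanical reformatting produced by running the updated hooks over the repository, presumably a black-style formatter (quote normalization, line wrapping) plus clang-format for C/C++ sources with CUDA files excluded. A representative before/after pair taken from the first hunk:

    # before: single-quoted device string
    return orig_empty(*args, **kwargs, device=torch.device('meta'))
    # after: double quotes enforced by the formatter
    return orig_empty(*args, **kwargs, device=torch.device("meta"))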
@@ -3,7 +3,7 @@
|
||||
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
# for more meta_registrations
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -24,25 +24,23 @@ orig_empty_like = torch.empty_like
|
||||
|
||||
|
||||
def new(*args, **kwargs):
|
||||
return orig_empty(*args, **kwargs, device=torch.device('meta'))
|
||||
return orig_empty(*args, **kwargs, device=torch.device("meta"))
|
||||
|
||||
|
||||
def new_strided(*args, **kwargs):
|
||||
return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
|
||||
return orig_empty_strided(*args, **kwargs, device=torch.device("meta"))
|
||||
|
||||
|
||||
def new_like(*args, **kwargs):
|
||||
return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
|
||||
return orig_empty_like(*args, **kwargs, device=torch.device("meta"))
|
||||
|
||||
|
||||
def register_meta(op, register_dispatcher=True):
|
||||
|
||||
def wrapper(f):
|
||||
|
||||
def add_func(op):
|
||||
meta_table[op] = f
|
||||
if register_dispatcher:
|
||||
name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
|
||||
name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
|
||||
try:
|
||||
meta_lib.impl(name, f)
|
||||
except:
|
||||
@@ -54,7 +52,7 @@ def register_meta(op, register_dispatcher=True):
|
||||
return wrapper
|
||||
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
if version.parse(torch.__version__) >= version.parse("1.12.0"):
|
||||
# ============================== Convolutions ======================================
|
||||
# https://github.com/pytorch/pytorch/pull/79834
|
||||
@register_meta(aten.convolution.default)
|
||||
@@ -69,7 +67,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
output_padding: List[int],
|
||||
groups: int,
|
||||
):
|
||||
|
||||
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
|
||||
"""
|
||||
Formula to apply to calculate the length of some dimension of the output
|
||||
@@ -146,7 +143,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
kernel_size[i],
|
||||
stride[i],
|
||||
output_padding_list[i],
|
||||
))
|
||||
)
|
||||
)
|
||||
else:
|
||||
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
|
||||
return ret_shape
|
||||
@@ -180,19 +178,39 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
|
||||
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
|
||||
mem_fmt = pick_memory_format()
|
||||
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
|
||||
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
|
||||
return out
|
||||
|
||||
@register_meta(aten._convolution.default)
|
||||
def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
|
||||
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
|
||||
*extra_args):
|
||||
def meta__conv(
|
||||
input_tensor: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
is_transposed: bool,
|
||||
output_padding: List[int],
|
||||
groups: int,
|
||||
*extra_args,
|
||||
):
|
||||
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
|
||||
return out
|
||||
|
||||
@register_meta(aten.convolution_backward.default)
|
||||
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
|
||||
padding, dilation, transposed, output_padding, groups, output_mask):
|
||||
def meta_conv_backward(
|
||||
grad_output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias_sizes,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
transposed,
|
||||
output_padding,
|
||||
groups,
|
||||
output_mask,
|
||||
):
|
||||
return new_like(input), new_like(weight), new((bias_sizes))
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
|
||||
@@ -224,7 +242,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
batch_sizes,
|
||||
dropout_state,
|
||||
):
|
||||
|
||||
is_input_packed = len(batch_sizes) != 0
|
||||
if is_input_packed:
|
||||
seq_length = len(batch_sizes)
|
||||
@@ -240,8 +257,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
if is_input_packed:
|
||||
out_shape = [batch_sizes_sum, out_size * num_directions]
|
||||
else:
|
||||
out_shape = ([mini_batch, seq_length, out_size *
|
||||
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
|
||||
out_shape = (
|
||||
[mini_batch, seq_length, out_size * num_directions]
|
||||
if batch_first
|
||||
else [seq_length, mini_batch, out_size * num_directions]
|
||||
)
|
||||
output = input.new_empty(out_shape)
|
||||
|
||||
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
|
||||
@@ -257,15 +277,21 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
|
||||
@register_meta(aten._cudnn_rnn_backward.default)
|
||||
def meta_cudnn_rnn_backward(input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_stride0: int,
|
||||
hx: torch.Tensor,
|
||||
cx: Optional[torch.Tensor] = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
|
||||
()) # (grad_input, grad_weight, grad_hx, grad_cx)
|
||||
def meta_cudnn_rnn_backward(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_stride0: int,
|
||||
hx: torch.Tensor,
|
||||
cx: Optional[torch.Tensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
return (
|
||||
new_like(input),
|
||||
new_like(weight),
|
||||
new_like(hx),
|
||||
new_like(cx) if cx is not None else new(()),
|
||||
) # (grad_input, grad_weight, grad_hx, grad_cx)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
|
||||
# ============================== Activations =======================================
|
||||
@@ -278,7 +304,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
aten.hardtanh_backward.default,
|
||||
]
|
||||
|
||||
if version.parse(torch.__version__) < version.parse('2.0.0'):
|
||||
if version.parse(torch.__version__) < version.parse("2.0.0"):
|
||||
_unregistered_ewise += [
|
||||
aten.prelu_backward.default,
|
||||
]
|
||||
@@ -296,37 +322,61 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
|
||||
@register_meta(aten.native_batch_norm_backward.default)
|
||||
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
|
||||
save_mean, save_invstd, train, eps, output_mask):
|
||||
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
|
||||
def meta_bn_backward(
|
||||
dY: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
running_mean,
|
||||
running_var,
|
||||
save_mean,
|
||||
save_invstd,
|
||||
train,
|
||||
eps,
|
||||
output_mask,
|
||||
):
|
||||
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
|
||||
@register_meta(aten.cudnn_batch_norm.default)
|
||||
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
|
||||
n_input = input.size(1)
|
||||
return new_like(input), new((n_input)), new((n_input)), new(
|
||||
(0), dtype=torch.uint8) # (output, running_mean, running_var, reserve)
|
||||
return (
|
||||
new_like(input),
|
||||
new((n_input)),
|
||||
new((n_input)),
|
||||
new((0), dtype=torch.uint8),
|
||||
) # (output, running_mean, running_var, reserve)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
|
||||
# NB: CuDNN only implements the backward algorithm for batchnorm
|
||||
# in training mode (evaluation mode batchnorm has a different algorithm),
|
||||
# which is why this doesn't accept a 'training' parameter.
|
||||
@register_meta(aten.cudnn_batch_norm_backward.default)
|
||||
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
|
||||
save_mean, save_invstd, eps, reserve):
|
||||
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
|
||||
def meta_cudnn_bn_backward(
|
||||
dY: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
running_mean,
|
||||
running_var,
|
||||
save_mean,
|
||||
save_invstd,
|
||||
eps,
|
||||
reserve,
|
||||
):
|
||||
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
|
||||
@register_meta(aten.native_layer_norm.default)
|
||||
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
|
||||
bs, n_input = input.size(0), input.size(1)
|
||||
return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
|
||||
return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
|
||||
@register_meta(aten.native_layer_norm_backward.default)
|
||||
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
|
||||
grad_input_mask):
|
||||
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
|
||||
def meta_ln_backward(
|
||||
dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
|
||||
):
|
||||
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
|
||||
|
||||
# ================================== Misc ==========================================
|
||||
# Maybe incorrect
|
||||
@@ -355,8 +405,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
|
||||
|
||||
@register_meta(aten.embedding_dense_backward.default)
|
||||
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
|
||||
scale_grad_by_freq):
|
||||
def meta_embedding_dense_backward(
|
||||
grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
|
||||
):
|
||||
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
|
||||
|
||||
# ============================== Dropout ===========================================
|
||||
@@ -364,14 +415,14 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
@register_meta(aten.native_dropout.default)
|
||||
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
|
||||
# notice that mask is bool
|
||||
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
|
||||
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
|
||||
@register_meta(aten.native_dropout_backward.default)
|
||||
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
|
||||
return new_like(grad) # (grad_in)
|
||||
return new_like(grad) # (grad_in)
|
||||
|
||||
if version.parse(torch.__version__) < version.parse('1.13.0'):
|
||||
if version.parse(torch.__version__) < version.parse("1.13.0"):
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
@register_meta(aten.eye.m_out)
|
||||
def meta_eye(n: int, m: int, out: torch.Tensor):
|
||||
@@ -385,24 +436,28 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
result: List[Optional[torch.Tensor]] = []
|
||||
for i, index in enumerate(indices):
|
||||
if index is not None:
|
||||
assert index.dtype in [torch.long, torch.int8, torch.bool],\
|
||||
"tensors used as indices must be long, byte or bool tensors"
|
||||
assert index.dtype in [
|
||||
torch.long,
|
||||
torch.int8,
|
||||
torch.bool,
|
||||
], "tensors used as indices must be long, byte or bool tensors"
|
||||
if index.dtype in [torch.int8, torch.bool]:
|
||||
nonzero = index.nonzero()
|
||||
k = len(result)
|
||||
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
|
||||
for j in range(index.ndim):
|
||||
assert index.shape[j] == self.shape[
|
||||
k +
|
||||
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
|
||||
assert (
|
||||
index.shape[j] == self.shape[k + j]
|
||||
), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
|
||||
result.append(nonzero.select(1, j))
|
||||
else:
|
||||
result.append(index)
|
||||
else:
|
||||
result.append(index)
|
||||
indices = result
|
||||
assert len(
|
||||
indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
|
||||
assert (
|
||||
len(indices) <= self.ndim
|
||||
), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
|
||||
# expand_outplace
|
||||
import torch._refs as refs
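Editor's note (illustrative, not part of the commit): register_meta above installs a shape-only "meta" implementation for an aten overload, and the new/new_like helpers build empty outputs on the meta device. A minimal sketch of how a further op could be registered using the helpers defined in this file; the choice of aten.relu.default is hypothetical:

    # Hypothetical registration: a meta kernel only has to return outputs with
    # the right shape/dtype; new_like() allocates them on the meta device.
    @register_meta(aten.relu.default)
    def meta_relu(input: torch.Tensor):
        return new_like(input)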
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from packaging import version
|
||||
|
||||
__all__ = [
|
||||
@@ -48,7 +47,7 @@ _DistCommMethod = [
|
||||
"scatter",
|
||||
]
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
if version.parse(torch.__version__) >= version.parse("1.12.0"):
|
||||
aten = torch.ops.aten
|
||||
# TODO: dive deep here
|
||||
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
|
||||
|
@@ -8,7 +8,7 @@ from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from functools import partial, reduce
|
||||
from numbers import Number
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -36,15 +36,15 @@ def _format_flops(flop):
|
||||
B = 1e9
|
||||
T = 1e12
|
||||
if flop < K:
|
||||
return f'{flop:.2f}'
|
||||
return f"{flop:.2f}"
|
||||
elif flop < M:
|
||||
return f'{flop / K:.2f}K'
|
||||
return f"{flop / K:.2f}K"
|
||||
elif flop < B:
|
||||
return f'{flop / M:.2f}M'
|
||||
return f"{flop / M:.2f}M"
|
||||
elif flop < T:
|
||||
return f'{flop / B:.2f}B'
|
||||
return f"{flop / B:.2f}B"
|
||||
else:
|
||||
return f'{flop / T:.2f}T'
|
||||
return f"{flop / T:.2f}T"
|
||||
|
||||
|
||||
def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
|
||||
@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
|
||||
Returns:
|
||||
Number: The total number of floating point operations (FWD + BWD).
|
||||
"""
|
||||
maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
|
||||
or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
|
||||
maybe_inplace = (
|
||||
getattr(module, "inplace", False)
|
||||
or kwargs.get("inplace", False)
|
||||
or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_")
|
||||
)
|
||||
|
||||
class DummyModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, func):
|
||||
super().__init__()
|
||||
self.func = func
|
||||
@@ -74,21 +76,20 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
|
||||
|
||||
total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
|
||||
flop_counts = defaultdict(lambda: defaultdict(int))
|
||||
parents = ['Global']
|
||||
parents = ["Global"]
|
||||
module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
|
||||
|
||||
class FlopTensor(MetaTensor):
|
||||
_tensor: torch.Tensor
|
||||
|
||||
def __repr__(self):
|
||||
name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
|
||||
name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor"
|
||||
if self.grad_fn:
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
||||
# no_dispatch is only needed if you use enable_python_mode.
|
||||
# It prevents infinite recursion.
|
||||
rs = super().__torch_dispatch__(func, types, args, kwargs)
|
||||
@@ -115,9 +116,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
|
||||
return isinstance(x, torch.Tensor) and x.is_floating_point()
|
||||
|
||||
def create_backwards_push(name):
|
||||
|
||||
class PushState(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *args):
|
||||
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
|
||||
@@ -134,9 +133,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
|
||||
return PushState.apply
|
||||
|
||||
def create_backwards_pop(name):
|
||||
|
||||
class PopState(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *args):
|
||||
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
|
||||
@@ -147,14 +144,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outs):
|
||||
nonlocal parents
|
||||
assert (parents[-1] == name)
|
||||
assert parents[-1] == name
|
||||
parents.pop()
|
||||
return grad_outs
|
||||
|
||||
return PopState.apply
|
||||
|
||||
def enter_module(name):
|
||||
|
||||
def f(module, inputs):
|
||||
nonlocal parents
|
||||
parents.append(name)
|
||||
@@ -165,10 +161,9 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
|
||||
return f
|
||||
|
||||
def exit_module(name):
|
||||
|
||||
def f(module, inputs, outputs):
|
||||
nonlocal parents
|
||||
assert (parents[-1] == name)
|
||||
assert parents[-1] == name
|
||||
parents.pop()
|
||||
outputs = normalize_tuple(outputs)
|
||||
return create_backwards_push(name)(*outputs)
|
||||
@@ -189,7 +184,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
|
||||
for mod in flop_counts.keys():
|
||||
print(f"Module: ", mod)
|
||||
for k, v in flop_counts[mod].items():
|
||||
print('\t', k, _format_flops(v))
|
||||
print("\t", k, _format_flops(v))
|
||||
print()
|
||||
|
||||
def detach_variables(r):
|
||||
@@ -201,7 +196,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
|
||||
|
||||
def wrap(r):
|
||||
if isinstance(r, torch.Tensor):
|
||||
data_ptr_fn = getattr(r, '_tensor', r).data_ptr
|
||||
data_ptr_fn = getattr(r, "_tensor", r).data_ptr
|
||||
r = FlopTensor(detach_variables(r))
|
||||
if maybe_inplace:
|
||||
r = r + 0
|
||||
@@ -375,8 +370,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
|
||||
# Inputs[0] contains the shape of the input.
|
||||
input_shape = inputs[input_arg_index].shape
|
||||
|
||||
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
|
||||
'shape') else inputs[affine_arg_index]
|
||||
has_affine = (
|
||||
inputs[affine_arg_index].shape is not None
|
||||
if hasattr(inputs[affine_arg_index], "shape")
|
||||
else inputs[affine_arg_index]
|
||||
)
|
||||
assert 2 <= len(input_shape) <= 5, input_shape
|
||||
# 5 is just a rough estimate
|
||||
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
|
||||
@@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
|
||||
training = inputs[-3]
|
||||
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
|
||||
if training:
|
||||
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
|
||||
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
|
||||
has_affine = inputs[1].shape is not None
|
||||
input_shape = reduce(operator.mul, inputs[0].shape)
|
||||
return input_shape * (2 if has_affine else 1)
|
||||
@@ -420,33 +418,30 @@ def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Calla
|
||||
|
||||
def zero_flop_jit(*args):
|
||||
"""
|
||||
Count flops for zero flop layers.
|
||||
Count flops for zero flop layers.
|
||||
"""
|
||||
return 0
|
||||
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
if version.parse(torch.__version__) >= version.parse("1.12.0"):
|
||||
flop_mapping = {
|
||||
# gemm
|
||||
# gemm
|
||||
aten.mm.default: matmul_flop_jit,
|
||||
aten.matmul.default: matmul_flop_jit,
|
||||
aten.addmm.default: addmm_flop_jit,
|
||||
aten.bmm.default: bmm_flop_jit,
|
||||
|
||||
# convolution
|
||||
# convolution
|
||||
aten.convolution.default: conv_flop_jit,
|
||||
aten._convolution.default: conv_flop_jit,
|
||||
aten.convolution_backward.default: conv_backward_flop_jit,
|
||||
|
||||
# normalization
|
||||
# normalization
|
||||
aten.native_batch_norm.default: batchnorm_flop_jit,
|
||||
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
|
||||
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
|
||||
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
|
||||
aten.native_layer_norm.default: norm_flop_counter(2, 0),
|
||||
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
|
||||
|
||||
# pooling
|
||||
# pooling
|
||||
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
||||
@@ -469,7 +464,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
}
|
||||
|
||||
ewise_flop_aten = [
|
||||
# basic op
|
||||
# basic op
|
||||
aten.add.Tensor,
|
||||
aten.add_.Tensor,
|
||||
aten.div.Tensor,
|
||||
@@ -485,8 +480,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
aten.sum.default,
|
||||
aten.sum.dim_IntList,
|
||||
aten.mean.dim,
|
||||
|
||||
# activation op
|
||||
# activation op
|
||||
aten.hardswish.default,
|
||||
aten.hardswish_.default,
|
||||
aten.hardswish_backward.default,
|
||||
@@ -509,15 +503,12 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
aten.tanh.default,
|
||||
aten.tanh_backward.default,
|
||||
aten.threshold_backward.default,
|
||||
|
||||
# dropout
|
||||
# dropout
|
||||
aten.native_dropout.default,
|
||||
aten.native_dropout_backward.default,
|
||||
|
||||
# distribution
|
||||
# distribution
|
||||
aten.bernoulli_.float,
|
||||
|
||||
# where
|
||||
# where
|
||||
aten.where.self,
|
||||
]
|
||||
for op in ewise_flop_aten:
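Editor's note (illustrative, not part of the commit): flop_count wraps the module's tensors in FlopTensor and accumulates per-op FLOPs from flop_mapping via __torch_dispatch__, returning the forward plus backward total. A hedged usage sketch; the import path is assumed from this file's package and may differ:

    import torch
    from colossalai._analyzer._subclasses.flop_tensor import flop_count  # assumed module path

    model = torch.nn.Linear(128, 256)
    x = torch.randn(4, 128, requires_grad=True)
    total = flop_count(model, x, verbose=True)  # verbose prints per-module counts; returns FWD + BWD FLOPs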
@@ -3,12 +3,12 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.types import _bool, _device, _dtype
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
from torch.types import _device
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
|
||||
|
||||
__all__ = ['MetaTensor', 'MetaTensorMode']
|
||||
__all__ = ["MetaTensor", "MetaTensorMode"]
|
||||
|
||||
|
||||
def register_storage(r, data_ptr_fn=None):
|
||||
@@ -28,8 +28,7 @@ def _normalize_tuple(x):
|
||||
|
||||
# a hack of inplace execution in PyTorch
|
||||
def _assert_alias(func):
|
||||
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive
|
||||
)
|
||||
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen) # TODO: check if should be this aggressive
|
||||
|
||||
|
||||
class MetaTensor(torch.Tensor):
|
||||
@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor):
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
|
||||
requires_grad=requires_grad) # deceive the frontend for aten selections
|
||||
device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
|
||||
requires_grad=requires_grad,
|
||||
) # deceive the frontend for aten selections
|
||||
r._tensor = elem
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
if not r._tensor.is_meta:
|
||||
val = elem.data_ptr()
|
||||
data_ptr_fn = lambda: val
|
||||
r._tensor = r._tensor.to(torch.device('meta'))
|
||||
r._tensor = r._tensor.to(torch.device("meta"))
|
||||
|
||||
# only tensor not on `meta` should be copied to `meta`
|
||||
register_storage(r._tensor, data_ptr_fn)
|
||||
@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor):
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
|
||||
name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
|
||||
if self.grad_fn:
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
|
||||
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
|
||||
@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor):
|
||||
x = x._tensor
|
||||
elif isinstance(x, torch.Tensor):
|
||||
device = x.device
|
||||
x = x.to(torch.device('meta'))
|
||||
x = x.to(torch.device("meta"))
|
||||
return x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
if 'device' in kwargs:
|
||||
device = kwargs['device']
|
||||
kwargs['device'] = torch.device('meta')
|
||||
if "device" in kwargs:
|
||||
device = kwargs["device"]
|
||||
kwargs["device"] = torch.device("meta")
|
||||
|
||||
# run aten for backend=CPU but actually on backend=Meta
|
||||
# here we detect whether or not the execution generates a physical copy
|
||||
@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor):
|
||||
nonlocal device
|
||||
if isinstance(x, str) or isinstance(x, _device):
|
||||
device = x
|
||||
return torch.device('meta')
|
||||
return torch.device("meta")
|
||||
return x
|
||||
|
||||
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
|
||||
return MetaTensor(elem, device=device)
|
||||
|
||||
def cpu(self, *args, **kwargs):
|
||||
if self.device.type == 'cpu':
|
||||
if self.device.type == "cpu":
|
||||
return self.to(*args, **kwargs)
|
||||
return self.to(*args, device='cpu', **kwargs)
|
||||
return self.to(*args, device="cpu", **kwargs)
|
||||
|
||||
def cuda(self, device=None, non_blocking=False):
|
||||
if device is not None:
|
||||
return self.to(device=device, non_blocking=non_blocking)
|
||||
return self.to(device='cuda:0', non_blocking=non_blocking)
|
||||
return self.to(device="cuda:0", non_blocking=non_blocking)
|
||||
|
||||
def data_ptr(self):
|
||||
return self._tensor.data_ptr()
|
||||
@@ -177,19 +177,17 @@ class MetaTensorMode(object):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.torch_overrides = {} # override torch.xxx
|
||||
self.dist_overrides = {} # override torch.distributed.xxx
|
||||
self.torch_overrides = {} # override torch.xxx
|
||||
self.dist_overrides = {} # override torch.distributed.xxx
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
def _dummy(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def _new(*args, orig_new=torch.empty, **kwargs):
|
||||
return MetaTensor(orig_new(*args, **{
|
||||
**kwargs, 'device': 'meta'
|
||||
}),
|
||||
device=kwargs.get('device', torch.device('cpu')))
|
||||
return MetaTensor(
|
||||
orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
|
||||
)
|
||||
|
||||
for func in _TorchOverrideableFactoryMethod:
|
||||
self.torch_overrides[func] = getattr(torch, func)
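Editor's note (illustrative, not part of the commit): MetaTensor keeps the real storage on the meta device while reporting the device the caller asked for, and MetaTensorMode patches torch factory functions (and torch.distributed methods) so models can be materialized without allocating memory. A rough usage sketch based on the classes defined in this file:

    # Inside the mode, factory calls such as torch.empty produce MetaTensors:
    # storage lives on the 'meta' device, but the tensor still reports 'cuda:0'.
    with MetaTensorMode():
        x = torch.empty(1024, 1024, device="cuda:0")
        y = x @ x  # shape propagation only; no real compute or memory

    print(type(x).__name__, x.device)  # MetaTensor cuda:0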
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -22,7 +22,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
|
||||
import colossalai
|
||||
from colossalai.fx._compatibility import compatibility
|
||||
|
||||
_register_custom_builtin('colossalai', 'import colossalai', colossalai)
|
||||
_register_custom_builtin("colossalai", "import colossalai", colossalai)
|
||||
|
||||
|
||||
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
|
||||
@@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
|
||||
"""
|
||||
Generate the checkpoint function call code text
|
||||
"""
|
||||
outputs = ', '.join(output_vars)
|
||||
inputs = ', '.join(input_vars)
|
||||
return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
|
||||
outputs = ", ".join(output_vars)
|
||||
inputs = ", ".join(input_vars)
|
||||
return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})"
|
||||
|
||||
|
||||
def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
|
||||
"""
|
||||
Check if the node could end the ckpt region at `ckpt_level`
|
||||
"""
|
||||
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
|
||||
return node.meta['info'].activation_checkpoint[ckpt_level] is not None
|
||||
if len(node.meta["info"].activation_checkpoint) > ckpt_level:
|
||||
return node.meta["info"].activation_checkpoint[ckpt_level] is not None
|
||||
return True
|
||||
|
||||
|
||||
@@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
|
||||
current_region = None
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
|
||||
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
|
||||
if len(node.meta["info"].activation_checkpoint) > ckpt_level:
|
||||
act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level]
|
||||
|
||||
# this activation checkpoint label is not set yet
|
||||
# meaning this is the first node of the activation ckpt region
|
||||
@@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
|
||||
return ckpt_regions
|
||||
|
||||
|
||||
def emit_ckpt_func(body,
|
||||
ckpt_func,
|
||||
node_list: List[Node],
|
||||
emit_node_func,
|
||||
delete_unused_value_func,
|
||||
ckpt_level=0,
|
||||
in_ckpt=False):
|
||||
def emit_ckpt_func(
|
||||
body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False
|
||||
):
|
||||
"""Emit ckpt function in nested way
|
||||
|
||||
Args:
|
||||
@@ -156,12 +152,12 @@ def emit_ckpt_func(body,
|
||||
|
||||
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
|
||||
# the label will be '0_1_1'
|
||||
label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
|
||||
label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]])
|
||||
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
|
||||
ckpt_func.append(f'{ckpt_fn_def}\n')
|
||||
ckpt_func.append(f"{ckpt_fn_def}\n")
|
||||
|
||||
# if there is more level to fetch
|
||||
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
|
||||
if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)):
|
||||
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
end_idx = [item[1] for item in ckpt_regions]
|
||||
@@ -174,33 +170,40 @@ def emit_ckpt_func(body,
|
||||
break
|
||||
|
||||
if node_idx in start_idx:
|
||||
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
|
||||
emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
|
||||
ckpt_level + 1, True)
|
||||
ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
|
||||
emit_ckpt_func(
|
||||
ckpt_func,
|
||||
ckpt_func_buffer,
|
||||
ckpt_node_list,
|
||||
emit_node_func,
|
||||
delete_unused_value_func,
|
||||
ckpt_level + 1,
|
||||
True,
|
||||
)
|
||||
node_idx += len(ckpt_node_list)
|
||||
|
||||
else:
|
||||
node = node_list[node_idx]
|
||||
emit_node_func(node, ckpt_func)
|
||||
ckpt_func[-1] = ' ' + ckpt_func[-1]
|
||||
ckpt_func[-1] = " " + ckpt_func[-1]
|
||||
delete_unused_value_func(node, ckpt_func)
|
||||
node_idx += 1
|
||||
|
||||
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
|
||||
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
|
||||
ckpt_func += ckpt_func_buffer
|
||||
|
||||
# last level
|
||||
else:
|
||||
for node in node_list:
|
||||
emit_node_func(node, ckpt_func)
|
||||
ckpt_func[-1] = ' ' + ckpt_func[-1]
|
||||
ckpt_func[-1] = " " + ckpt_func[-1]
|
||||
delete_unused_value_func(node, ckpt_func)
|
||||
|
||||
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
|
||||
ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
|
||||
|
||||
usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
|
||||
usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n"
|
||||
if in_ckpt:
|
||||
usage = ' ' + usage
|
||||
usage = " " + usage
|
||||
body.append(usage)
|
||||
|
||||
|
||||
@@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||
|
||||
# process ckpt_regions
|
||||
if node_idx in start_idx:
|
||||
ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
|
||||
ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
|
||||
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
|
||||
node_idx += len(ckpt_node_list)
|
||||
|
||||
@@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
class ActivationCheckpointCodeGen(CodeGen):
|
||||
|
||||
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
|
||||
free_vars: List[str] = []
|
||||
body: List[str] = []
|
||||
@@ -251,7 +253,7 @@ class ActivationCheckpointCodeGen(CodeGen):
|
||||
wrapped_fns: Dict[str, None] = {}
|
||||
|
||||
# Wrap string in list to pass by reference
|
||||
maybe_return_annotation: List[str] = ['']
|
||||
maybe_return_annotation: List[str] = [""]
|
||||
|
||||
def add_global(name_hint: str, obj: Any):
|
||||
"""Add an obj to be tracked as a global.
|
||||
@@ -259,7 +261,7 @@ class ActivationCheckpointCodeGen(CodeGen):
|
||||
Graph, like functions or types.
|
||||
Returns: the global name that should be used to reference 'obj' in generated source.
|
||||
"""
|
||||
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
|
||||
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
|
||||
# HACK: workaround for how torch custom ops are registered. We
|
||||
# can't import them like normal modules so they must retain their
|
||||
# fully qualified name.
|
||||
@@ -281,16 +283,16 @@ class ActivationCheckpointCodeGen(CodeGen):
|
||||
def type_repr(o: Any):
|
||||
if o == ():
|
||||
# Empty tuple is used for empty tuple type annotation Tuple[()]
|
||||
return '()'
|
||||
return "()"
|
||||
|
||||
typename = _type_repr(o)
|
||||
|
||||
if hasattr(o, '__origin__'):
|
||||
if hasattr(o, "__origin__"):
|
||||
# This is a generic type, e.g. typing.List[torch.Tensor]
|
||||
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
|
||||
origin_typename = add_global(_type_repr(origin_type), origin_type)
|
||||
|
||||
if hasattr(o, '__args__'):
|
||||
if hasattr(o, "__args__"):
|
||||
# Assign global names for each of the inner type variables.
|
||||
args = [type_repr(arg) for arg in o.__args__]
|
||||
|
||||
@@ -309,19 +311,18 @@ class ActivationCheckpointCodeGen(CodeGen):
|
||||
return add_global(typename, o)
|
||||
|
||||
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
|
||||
|
||||
def _get_repr(arg):
|
||||
# Handle NamedTuples (if it has `_fields`) via add_global.
|
||||
if isinstance(arg, tuple) and hasattr(arg, '_fields'):
|
||||
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
|
||||
qualified_name = _get_qualified_name(type(arg))
|
||||
global_name = add_global(qualified_name, type(arg))
|
||||
return f"{global_name}{repr(tuple(arg))}"
|
||||
return repr(arg)
|
||||
|
||||
args_s = ', '.join(_get_repr(a) for a in args)
|
||||
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
|
||||
args_s = ", ".join(_get_repr(a) for a in args)
|
||||
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
|
||||
if args_s and kwargs_s:
|
||||
return f'{args_s}, {kwargs_s}'
|
||||
return f"{args_s}, {kwargs_s}"
|
||||
return args_s or kwargs_s
|
||||
|
||||
# Run through reverse nodes and record the first instance of a use
|
||||
@@ -347,82 +348,94 @@ class ActivationCheckpointCodeGen(CodeGen):
|
||||
not used in the remainder of the code are freed and the memory usage
|
||||
of the code is optimal.
|
||||
"""
|
||||
if user.op == 'placeholder':
|
||||
if user.op == "placeholder":
|
||||
return
|
||||
if user.op == 'output':
|
||||
body.append('\n')
|
||||
if user.op == "output":
|
||||
body.append("\n")
|
||||
return
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
if len(nodes_to_delete):
|
||||
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
|
||||
body.append(f'; {to_delete_str}\n')
|
||||
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
|
||||
body.append(f"; {to_delete_str}\n")
|
||||
else:
|
||||
body.append('\n')
|
||||
body.append("\n")
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def emit_node(node: Node, body):
|
||||
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
|
||||
if node.op == 'placeholder':
|
||||
maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
|
||||
if node.op == "placeholder":
|
||||
assert isinstance(node.target, str)
|
||||
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
|
||||
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
|
||||
raw_name = node.target.replace('*', '')
|
||||
maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
|
||||
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
|
||||
raw_name = node.target.replace("*", "")
|
||||
if raw_name != repr(node):
|
||||
body.append(f'{repr(node)} = {raw_name}\n')
|
||||
body.append(f"{repr(node)} = {raw_name}\n")
|
||||
return
|
||||
elif node.op == 'call_method':
|
||||
elif node.op == "call_method":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
|
||||
f'({_format_args(node.args[1:], node.kwargs)})')
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
|
||||
f"({_format_args(node.args[1:], node.kwargs)})"
|
||||
)
|
||||
return
|
||||
elif node.op == 'call_function':
|
||||
elif node.op == "call_function":
|
||||
assert callable(node.target)
|
||||
# pretty print operators
|
||||
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
|
||||
if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
|
||||
assert isinstance(node.args, tuple)
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = '
|
||||
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
|
||||
)
|
||||
return
|
||||
|
||||
# pretty print inplace operators; required for jit.script to work properly
|
||||
# not currently supported in normal FX graphs, but generated by torchdynamo
|
||||
if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
|
||||
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
|
||||
f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
|
||||
if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
|
||||
body.append(
|
||||
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
|
||||
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
|
||||
)
|
||||
return
|
||||
|
||||
qualified_name = _get_qualified_name(node.target)
|
||||
global_name = add_global(qualified_name, node.target)
|
||||
# special case for getattr: node.args could be 2-argument or 3-argument
|
||||
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
|
||||
if global_name == 'getattr' and \
|
||||
isinstance(node.args, tuple) and \
|
||||
isinstance(node.args[1], str) and \
|
||||
node.args[1].isidentifier() and \
|
||||
len(node.args) == 2:
|
||||
if (
|
||||
global_name == "getattr"
|
||||
and isinstance(node.args, tuple)
|
||||
and isinstance(node.args[1], str)
|
||||
and node.args[1].isidentifier()
|
||||
and len(node.args) == 2
|
||||
):
|
||||
body.append(
|
||||
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
|
||||
)
|
||||
return
|
||||
body.append(
|
||||
f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
|
||||
if node.meta.get('is_wrapped', False):
|
||||
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
|
||||
)
|
||||
if node.meta.get("is_wrapped", False):
|
||||
wrapped_fns.setdefault(global_name)
|
||||
return
|
||||
elif node.op == 'call_module':
|
||||
elif node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = '
|
||||
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
|
||||
)
|
||||
return
|
||||
elif node.op == 'get_attr':
|
||||
elif node.op == "get_attr":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
|
||||
body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
|
||||
return
|
||||
elif node.op == 'output':
|
||||
elif node.op == "output":
|
||||
if node.type is not None:
|
||||
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
|
||||
body.append(self.generate_output(node.args[0]))
|
||||
return
|
||||
raise NotImplementedError(f'node: {node.op} {node.target}')
|
||||
raise NotImplementedError(f"node: {node.op} {node.target}")
|
||||
|
||||
# Modified for activation checkpointing
|
||||
ckpt_func = []
|
||||
@@ -432,13 +445,13 @@ class ActivationCheckpointCodeGen(CodeGen):
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
# have been emitted. To continue to have valid Python code, emit a
|
||||
# single pass statement
|
||||
body.append('pass\n')
|
||||
body.append("pass\n")
|
||||
|
||||
if len(wrapped_fns) > 0:
|
||||
wrap_name = add_global('wrap', torch.fx.wrap)
|
||||
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
|
||||
wrap_name = add_global("wrap", torch.fx.wrap)
|
||||
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
|
||||
else:
|
||||
wrap_stmts = ''
|
||||
wrap_stmts = ""
|
||||
|
||||
if self._body_transformer:
|
||||
body = self._body_transformer(body)
|
||||
@@ -447,11 +460,11 @@ class ActivationCheckpointCodeGen(CodeGen):
|
||||
add_global(name, value)
|
||||
|
||||
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
|
||||
prologue = ''.join(ckpt_func) + prologue
|
||||
prologue = "".join(ckpt_func) + prologue
|
||||
prologue = prologue
|
||||
|
||||
code = ''.join(body)
|
||||
code = '\n'.join(' ' + line for line in code.split('\n'))
|
||||
code = "".join(body)
|
||||
code = "\n".join(" " + line for line in code.split("\n"))
|
||||
fn_code = f"""
|
||||
{wrap_stmts}
|
||||
{prologue}
|
||||
|
@@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode
|
||||
|
||||
try:
|
||||
from torch.fx.graph import _PyTreeCodeGen
|
||||
|
||||
SUPPORT_PT_CODEGEN = True
|
||||
except ImportError:
|
||||
SUPPORT_PT_CODEGEN = False
|
||||
@@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent
|
||||
# This is a copy of torch.fx.graph_module._WrappedCall.
|
||||
# It should be removed when we stop supporting torch < 1.12.0.
|
||||
class _WrappedCall:
|
||||
|
||||
def __init__(self, cls, cls_call):
|
||||
self.cls = cls
|
||||
self.cls_call = cls_call
|
||||
@@ -50,12 +50,14 @@ class _WrappedCall:
|
||||
|
||||
# constituent substrings of the error message
|
||||
tb_repr = traceback.format_exc()
|
||||
custom_msg = ("Call using an FX-traced Module, "
|
||||
f"line {err_lineno} of the traced Module's "
|
||||
"generated forward function:")
|
||||
before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
|
||||
custom_msg = (
|
||||
"Call using an FX-traced Module, "
|
||||
f"line {err_lineno} of the traced Module's "
|
||||
"generated forward function:"
|
||||
)
|
||||
before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
|
||||
marker = "~" * err_line_len + "~~~ <--- HERE"
|
||||
err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
|
||||
err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
|
||||
|
||||
# joined message
|
||||
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
|
||||
@@ -65,11 +67,14 @@ class _WrappedCall:
|
||||
if self.cls_call is not None:
|
||||
return self.cls_call(obj, *args, **kwargs)
|
||||
else:
|
||||
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
|
||||
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
|
||||
except Exception as e:
|
||||
assert e.__traceback__
|
||||
topmost_framesummary: traceback.FrameSummary = \
|
||||
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
|
||||
topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
|
||||
traceback.walk_tb(e.__traceback__)
|
||||
)[
|
||||
-1
|
||||
] # type: ignore[arg-type]
|
||||
if "eval_with_key" in topmost_framesummary.filename:
|
||||
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
|
||||
raise e.with_traceback(None)
|
||||
@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
|
||||
code.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
root: Union[torch.nn.Module, Dict[str, Any]],
|
||||
graph: torch.fx.Graph,
|
||||
class_name: str = 'GraphModule'):
|
||||
def __init__(
|
||||
self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
|
||||
):
|
||||
super().__init__(root, graph, class_name)
|
||||
|
||||
def bind(self, ckpt_def, globals):
|
||||
@@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule):
|
||||
if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
|
||||
self._in_spec = self._graph._codegen.pytree_info.in_spec
|
||||
self._out_spec = self._graph._codegen.pytree_info.out_spec
|
||||
python_code = self._graph.python_code(root_module='self')
|
||||
python_code = self._graph.python_code(root_module="self")
|
||||
self._code = python_code.src
|
||||
|
||||
# To split ckpt functions code and forward code
|
||||
@@ -157,8 +161,8 @@ class ColoGraphModule(torch.fx.GraphModule):
|
||||
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
|
||||
cls_call = cls.__call__ if "__call__" in vars(cls) else None
|
||||
|
||||
if '_wrapped_call' not in vars(cls):
|
||||
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
|
||||
if "_wrapped_call" not in vars(cls):
|
||||
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
|
||||
|
||||
def call_wrapped(self, *args, **kwargs):
|
||||
return self._wrapped_call(self, *args, **kwargs)
|
||||
@@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule):
|
||||
"""
|
||||
folder = Path(folder)
|
||||
Path(folder).mkdir(exist_ok=True)
|
||||
torch.save(self.state_dict(), folder / 'state_dict.pt')
|
||||
torch.save(self.state_dict(), folder / "state_dict.pt")
|
||||
tab = " " * 4
|
||||
|
||||
# we add import colossalai here
|
||||
@@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module):
|
||||
for module_name, module in self.named_children():
|
||||
module_str = _gen_model_repr(module_name, module)
|
||||
if module_str is None:
|
||||
module_file = folder / f'{module_name}.pt'
|
||||
module_file = folder / f"{module_name}.pt"
|
||||
torch.save(module, module_file)
|
||||
blobified_modules.append(module_name)
|
||||
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
|
||||
module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
|
||||
module_str = f"torch.load(r'{module_file}') # {module_repr}"
|
||||
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
|
||||
|
||||
@@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module):
|
||||
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
|
||||
model_str += f"{_addindent(self.code, 4)}\n"
|
||||
|
||||
module_file = folder / 'module.py'
|
||||
module_file = folder / "module.py"
|
||||
module_file.write_text(model_str)
|
||||
|
||||
init_file = folder / '__init__.py'
|
||||
init_file.write_text('from .module import *')
|
||||
init_file = folder / "__init__.py"
|
||||
init_file.write_text("from .module import *")
|
||||
|
||||
if len(blobified_modules) > 0:
|
||||
warnings.warn("Was not able to save the following children modules as reprs -"
|
||||
f"saved as pickled files instead: {blobified_modules}")
|
||||
warnings.warn(
|
||||
"Was not able to save the following children modules as reprs -"
|
||||
f"saved as pickled files instead: {blobified_modules}"
|
||||
)
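Editor's note (illustrative, not part of the commit): ColoGraphModule is a drop-in replacement for torch.fx.GraphModule whose generated code can carry the checkpoint wrapper functions emitted by the codegen above. A rough construction sketch; MyModel and the example input are placeholders:

    import torch
    from torch.fx import symbolic_trace

    model = MyModel()                          # any symbolically traceable nn.Module (placeholder)
    traced = symbolic_trace(model)
    gm = ColoGraphModule(model, traced.graph)  # same (root, graph, class_name) signature as fx.GraphModule
    out = gm(torch.randn(2, 16))               # placeholder input shape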
@@ -1,9 +1,9 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.autograd.profiler_util import _format_memory, _format_time
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
from torch.autograd.profiler_util import _format_memory
|
||||
from torch.fx import Node
|
||||
|
||||
from colossalai._analyzer.envs import MeshConfig
|
||||
|
||||
@@ -85,12 +85,12 @@ class MetaInfo:
|
||||
node: Node
|
||||
|
||||
# directory
|
||||
mod_dir: str = ''
|
||||
mod_dir: str = ""
|
||||
|
||||
# ctx[data_ptr] = Tensor
|
||||
# mark the storage for ctx.save_for_backward
|
||||
global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
|
||||
curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
|
||||
global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
|
||||
curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
|
||||
|
||||
# should be updated after each graph manipulation
|
||||
# ============================== Update ====================================
|
||||
@@ -100,7 +100,7 @@ class MetaInfo:
|
||||
|
||||
inputs: Tuple[torch.Tensor] = ()
|
||||
outputs: Tuple[torch.Tensor] = ()
|
||||
is_alias: Tuple[bool] = () # whether the output is an alias of input
|
||||
is_alias: Tuple[bool] = () # whether the output is an alias of input
|
||||
|
||||
# compute cost
|
||||
fwd_flop: Optional[int] = 0
|
||||
@@ -112,29 +112,29 @@ class MetaInfo:
|
||||
|
||||
# should keep the same whenever manipulated
|
||||
# ============================= Invariant ==================================
|
||||
activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
|
||||
activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
|
||||
to_offload: Optional[bool] = False
|
||||
sharding_spec: str = 'RR'
|
||||
sharding_spec: str = "RR"
|
||||
|
||||
def __new__(cls, node: Node, **kwargs):
|
||||
orig_init = cls.__init__
|
||||
|
||||
# if initialized, return the existing one
|
||||
# should disable the __init__ function
|
||||
if node.meta.get('info', None) is not None:
|
||||
if node.meta.get("info", None) is not None:
|
||||
|
||||
def _dummy(self, *args, **kwargs):
|
||||
if getattr(self, '_is_init', False):
|
||||
if getattr(self, "_is_init", False):
|
||||
self._is_init = True
|
||||
orig_init(self, *args, **kwargs)
|
||||
cls.__init__ = orig_init
|
||||
|
||||
cls.__init__ = _dummy
|
||||
return node.meta['info']
|
||||
return node.meta["info"]
|
||||
return super().__new__(cls)
|
||||
|
||||
def __post_init__(self):
|
||||
self.node.meta['info'] = self
|
||||
self.node.meta["info"] = self
|
||||
|
||||
@property
|
||||
def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
@@ -188,24 +188,26 @@ class MetaInfo:
return compute_size_in_bytes(self.inputs)

def __repr__(self):
s = f'Node {self.node.name}'
s = f"Node {self.node.name}"
if self.parameters:
s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
s += f"\n\thas parameter of size {_format_memory(self.param_size)}"
if self.buffers:
s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}"
if self.output_size:
s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
s += f"\n\thas output activation of size {_format_memory(self.output_size)}"
# if self.total_size:
# s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
if self.temp_size:
s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}"
if self.backward_size:
s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
s += f'\n\tfwd_flop = {self.fwd_flop}'\
f'\n\tbwd_flop = {self.bwd_flop}'\
f'\n\tfwd_comm = {self.fwd_comm}'\
f'\n\tbwd_comm = {self.bwd_comm}'\
f'\n\tto_recompute = {self.to_recompute}'\
f'\n\tto_offload = {self.to_offload}'\
f'\n\tsharding_spec = {self.sharding_spec}'
s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}"
s += (
f"\n\tfwd_flop = {self.fwd_flop}"
f"\n\tbwd_flop = {self.bwd_flop}"
f"\n\tfwd_comm = {self.fwd_comm}"
f"\n\tbwd_comm = {self.bwd_comm}"
f"\n\tto_recompute = {self.to_recompute}"
f"\n\tto_offload = {self.to_offload}"
f"\n\tsharding_spec = {self.sharding_spec}"
)
return s
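As context for the `__new__` and `__post_init__` logic above, a minimal usage sketch (constructing MetaInfo from a bare traced node is an assumption, not taken from this diff's context lines): building MetaInfo twice for the same FX node should hand back the object cached in node.meta["info"].

# Hypothetical sketch: MetaInfo caches itself on the traced node, so a second
# construction for the same node returns the cached instance instead of a new one.
import torch
from torch.fx import symbolic_trace
from colossalai._analyzer.fx.node_util import MetaInfo  # path taken from a hunk header later in this diff

gm = symbolic_trace(torch.nn.Linear(4, 4))
node = next(iter(gm.graph.nodes))
first = MetaInfo(node)    # __post_init__ stores it in node.meta["info"]
second = MetaInfo(node)   # __new__ finds node.meta["info"] and returns it
assert first is second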

@@ -1,8 +1,8 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
import torch.fx
from torch.autograd.profiler_util import _format_memory, _format_time
from torch.autograd.profiler_util import _format_memory
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target

@@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo
def _format_flops(flops: float) -> str:
"""Returns a formatted FLOP size string"""
if flops > 1e12:
return f'{flops / 1e12:.2f} TFLOPs'
return f"{flops / 1e12:.2f} TFLOPs"
elif flops > 1e9:
return f'{flops / 1e9:.2f} GFLOPs'
return f"{flops / 1e9:.2f} GFLOPs"
elif flops > 1e6:
return f'{flops / 1e6:.2f} MFLOPs'
return f"{flops / 1e6:.2f} MFLOPs"
elif flops > 1e3:
return f'{flops / 1e3:.2f} kFLOPs'
return f'{flops} FLOPs'
return f"{flops / 1e3:.2f} kFLOPs"
return f"{flops} FLOPs"
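A quick, hedged sanity check of the thresholds above (plain Python, run where `_format_flops` is defined):

# Hypothetical check of the formatting boundaries shown above.
assert _format_flops(2.5e9) == "2.50 GFLOPs"   # falls into the > 1e9 branch
assert _format_flops(512.0) == "512.0 FLOPs"   # below 1e3, returned as plain FLOPs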


def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
Fetch shape argument from ``ShapeProp`` without re-executing
the ``GraphModule`` from scratch.
"""

_profileable = [
'call_function',
'call_module',
'call_method',
"call_function",
"call_module",
"call_method",
]

def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
@@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter):
self.args_iter: Iterator[Any] = iter(args)

for node in self.module.graph.nodes:

self.run_node(node) # No need to store.
self.run_node(node) # No need to store.

if self.garbage_collect_values:
for to_delete in self.user_to_last_uses.get(node, []):
del self.env[to_delete]

if node.op == 'output':
if node.op == "output":
output_val = self.env[node]
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val

@@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
print(
"`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)

# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
@@ -145,36 +147,38 @@ class GraphProfiler(torch.fx.Interpreter):
node: Node
n_info = MetaInfo(node)
last_n_info = last_n_info or n_info
node_summaries.append([
node.op,
str(node),
_format_memory(n_info.accumulate_size),
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
_format_memory(n_info.output_size),
_format_memory(n_info.temp_size),
_format_memory(n_info.param_size),
_format_memory(n_info.backward_size),
_format_flops(n_info.fwd_flop),
_format_flops(n_info.bwd_flop),
])
node_summaries.append(
[
node.op,
str(node),
_format_memory(n_info.accumulate_size),
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
_format_memory(n_info.output_size),
_format_memory(n_info.temp_size),
_format_memory(n_info.param_size),
_format_memory(n_info.backward_size),
_format_flops(n_info.fwd_flop),
_format_flops(n_info.bwd_flop),
]
)
last_n_info = n_info

# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Accumulate size',
'Incremental size',
'Output size',
'Temp size',
'Param size',
'Backward size',
'Fwd FLOPs',
'Bwd FLOPs',
"Op type",
"Op",
"Accumulate size",
"Incremental size",
"Output size",
"Temp size",
"Param size",
"Backward size",
"Fwd FLOPs",
"Bwd FLOPs",
]

return tabulate(node_summaries, headers=headers, stralign='right')
return tabulate(node_summaries, headers=headers, stralign="right")


class CommunicationProfiler(GraphProfiler):
@@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler):
>>> def my_fn_flop_count_impl(*args, **kwargs):
>>> return 0, 0
"""

_custom_flop_count_impl = {}

def run_node(self, n: torch.fx.Node) -> Any:
@@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler):
(
n_info.fwd_flop,
n_info.bwd_flop,
) = getattr(self, n.op)(n.target, args, kwargs)
) = getattr(
self, n.op
)(n.target, args, kwargs)
except Exception as e:
raise RuntimeError(
f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
f"Please refer to function's docstring to register the relevant profile_impl for this node!"
) from e

# retain the autograd graph
@@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler):

return _denormalize_tuple(n_info.outputs)

def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the profiling result.
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
@@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler):
else:
return flop_count(target, *args, **kwargs)

def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the profiling result.

@@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler):
assert isinstance(target, str)
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)

def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node and return the profiling result.

@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
Returns:
GraphModule: The same GraphModule with profiling information
"""
for profiler_cls in (FlopProfiler,
# CommunicationProfiler, # TODO: add communication profiling
):
for profiler_cls in (
FlopProfiler,
# CommunicationProfiler, # TODO: add communication profiling
):
profiler = profiler_cls(module)
profiler.propagate(*args, device=_current_device(module))

@@ -54,7 +54,7 @@ def _current_device(module):
try:
return next(module.parameters()).device
except StopIteration:
return torch.device('cpu')
return torch.device("cpu")


@compatibility(is_backward_compatible=False)
@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
>>> # do something here
>>> return torch.empty(output_shape, device=output_device)
"""

_custom_dispatch_func = {}
_mode = MetaTensorMode()

@@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter):
r = getattr(self, n.op)(n.target, args, kwargs)

def unwrap_fn(elem):

def _convert_meta(t: torch.Tensor):
if t.device == 'meta':
if t.device == "meta":
return t
else:
return t.to('meta')
return t.to("meta")

if isinstance(elem, MetaTensor):
if getattr(self, '_is_param', False):
if getattr(self, "_is_param", False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor)

@@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter):
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)

if n.op == 'call_module':
if n.op == "call_module":
submod = self.fetch_attr(n.target)
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})

else:
n_info.parameters.update({
k.name: MetaTensor(v)
for k, v in zip(n.args, args)
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
})
n_info.parameters.update(
{
k.name: MetaTensor(v)
for k, v in zip(n.args, args)
if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
}
)
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})

n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
tuple(v for v in kwargs.values() if is_pure_tensor(v))
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
v for v in kwargs.values() if is_pure_tensor(v)
)

# align with SPMD
if isinstance(r, (tuple, list)):
@@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter):
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
return r

def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the result.
If the target of ``Node`` is registered with ``@register_shape_impl``,
@@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter):
else:
return res

def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.

@@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter):

convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
args[0], torch.nn.parameter.Parameter):
args[0], torch.nn.parameter.Parameter
):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)

@@ -1,5 +1,3 @@
import torch
import torch.fx
from torch.fx import GraphModule

from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
@@ -7,7 +5,6 @@ from .passes.graph_profile import FlopProfiler


def register_flop_count_impl(func):

def wrapper(impl):
FlopProfiler._custom_flop_count_impl[func] = impl
return impl
@@ -16,7 +13,6 @@ def register_flop_count_impl(func):


def register_shape_impl(func):

def wrapper(impl):
ShapeProp._custom_dispatch_func[func] = impl
return impl
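A minimal usage sketch of the two registration helpers above (decorator-factory usage inferred from `wrapper` returning `impl`; the import location is an assumption), matching the docstring example shown earlier for FlopProfiler:

# Hypothetical sketch: register a zero-cost flop count for a custom op so that
# FlopProfiler.run_node does not raise when it meets nodes targeting that op.
import torch
from colossalai._analyzer.fx import register_flop_count_impl  # assumed export location


def my_custom_op(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@register_flop_count_impl(my_custom_op)
def my_custom_op_flop_count(*args, **kwargs):
    # must return (fwd_flop, bwd_flop), mirroring the docstring example above
    return 0, 0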

@@ -12,7 +12,7 @@ from .tracer import register_tracer_impl
__all__ = []


@register_tracer_impl(F.linear, name='_bias_addition_impl')
@register_tracer_impl(F.linear, name="_bias_addition_impl")
def linear_impl(input, weight, bias=None):
if bias is None:
return F.linear(input, weight)
@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None):
return F.linear(input, weight) + bias


@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
@register_tracer_impl(F.conv1d, name="_bias_addition_impl")
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1))
(-1, 1)
)


@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
@register_tracer_impl(F.conv2d, name="_bias_addition_impl")
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1))
(-1, 1, 1)
)


@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
@register_tracer_impl(F.conv3d, name="_bias_addition_impl")
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
(-1, 1, 1, 1))
(-1, 1, 1, 1)
)


@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl")
def conv_transpose1d_impl(
input,
weight,
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1),
):
if bias is None:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose1d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
else:
return F.conv_transpose1d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))
return F.conv_transpose1d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1))


@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input,
weight,
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl")
def conv_transpose2d_impl(
input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)
):
if bias is None:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose2d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
else:
return F.conv_transpose2d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))
return F.conv_transpose2d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1, 1))


@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl")
def conv_transpose3d_impl(
input,
weight,
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1),
):
if bias is None:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
return F.conv_transpose3d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
)
else:
return F.conv_transpose3d(input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
return F.conv_transpose3d(
input,
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
) + bias.reshape((-1, 1, 1, 1))


@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.addmm, name="_bias_addition_impl")
@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl")
def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
return F.linear(mat1, mat2.transpose(0, 1)) + input


@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
@register_tracer_impl(torch.addbmm, name="_bias_addition_impl")
@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl")
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
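All of these `_bias_addition_impl` registrations follow one pattern: a fused op with a bias is rewritten into the bias-free op plus an explicit add, so the bias addition shows up as its own node when `bias_addition_split` is enabled. A small, hedged illustration of the rewrite that `linear_impl` above encodes:

# Hypothetical before/after of the rewrite encoded by linear_impl:
# one fused F.linear(x, w, b) call versus a bias-free matmul plus an explicit add.
import torch
import torch.nn.functional as F

x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)
fused = F.linear(x, w, b)     # what user code writes
split = F.linear(x, w) + b    # the equivalent form the tracer records instead
assert torch.allclose(fused, split)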

@@ -4,6 +4,7 @@ from .tracer import register_leaf_module, register_leaf_module_impl

try:
import apex

register_leaf_module(apex.normalization.FusedLayerNorm)
register_leaf_module(apex.normalization.FusedRMSNorm)
register_leaf_module(apex.normalization.MixedFusedLayerNorm)

@@ -1,10 +1,8 @@
import operator
from typing import Any, Callable, Dict, Optional, Set, Union
from typing import Any, Callable, Dict, Optional, Union

import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.fx import Node, Proxy
from torch.utils._pytree import tree_map

from colossalai._analyzer._subclasses import MetaTensor
@@ -32,7 +30,7 @@ class ColoProxy(Proxy):
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch:
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl
return proxy
@@ -72,7 +70,7 @@ class ColoProxy(Proxy):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))

def __setitem__(self, key, value):
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy

@@ -89,7 +87,6 @@ class ColoProxy(Proxy):


class ColoAttribute(ColoProxy):

def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
@@ -102,11 +99,11 @@ class ColoAttribute(ColoProxy):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node

def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)

def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch.fx import Tracer
@@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor

try:
from ..codegen import ActivationCheckpointCodeGen

SUPPORT_ACTIVATION = True
except:
SUPPORT_ACTIVATION = False
@@ -16,7 +17,7 @@ from .tracer import ColoTracer


def _default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def _current_device(module: torch.nn.Module):
@@ -144,10 +145,9 @@ def symbolic_trace(
if meta_args:
device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
bias_addition_split=bias_addition_split).trace(root.to(device),
concrete_args=concrete_args,
meta_args=tree_map(wrap_fn, meta_args))
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
)
if trace_act_ckpt and SUPPORT_ACTIVATION:
graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device)
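For reference, a minimal, hedged sketch of driving this path (argument names are taken from the code above, while the exact import path of `symbolic_trace`, its return type, and the example input shape are assumptions):

# Hypothetical usage sketch: tracing with meta_args so the example input is
# wrapped into a MetaTensor and no real activation memory is allocated.
import torch
from colossalai._analyzer.fx import symbolic_trace  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
gm = symbolic_trace(
    model,
    meta_args={"input": torch.empty(8, 16)},
    trace_act_ckpt=False,
    bias_addition_split=True,
)
print(gm.graph)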

@@ -20,11 +20,10 @@ def _truncate_suffix(s: str):
import re

# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
return re.sub(r'_\d+$', '', s)
return re.sub(r"_\d+$", "", s)


def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):

def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
def wrapper(impl):
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
getattr(ColoTracer, name)[func] = impl
@@ -34,7 +33,6 @@ def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custo


def register_leaf_module_impl(module: nn.Module):

def wrapper(impl):
ColoTracer._custom_leaf_module_impl[module] = impl
return impl
@@ -76,7 +74,7 @@ class ColoTracer(Tracer):
self.ckpt_regions = []
self.ckpt_idx = 0

self.mod_dir = ''
self.mod_dir = ""

# whether the tracer should split the bias_add ops into two ops
self.bias_addition_split = bias_addition_split
@@ -87,35 +85,41 @@ class ColoTracer(Tracer):
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False
# user can specify which modules are leaf modules and which are not
return (type(m) not in self._custom_non_leaf_module
and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
return type(m) not in self._custom_non_leaf_module and (
type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)
)

def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
kwargs: Dict[str, Any]) -> Any:
def call_module(
self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Any:
curr_dir = self.mod_dir
self.mod_dir = 'self.' + self.path_of_module(m)
self.mod_dir = "self." + self.path_of_module(m)
rst = super().call_module(m, forward, args, kwargs)
self.mod_dir = curr_dir
return rst

def proxy(self, node: Node) -> 'ColoProxy':
def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self)

def create_proxy(self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None):

def create_proxy(
self,
kind: str,
target: Target,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], "Proxy"] = None,
):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if kind == 'placeholder':
proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
_truncate_suffix(target), None)
elif kind == 'get_attr':
if kind == "placeholder":
proxy.meta_data = (
self.meta_args[target]
if target in self.meta_args
else self.concrete_args.get(_truncate_suffix(target), None)
)
elif kind == "get_attr":
self.disable_module_getattr = True
try:
attr_itr = self.root
@@ -125,20 +129,21 @@ class ColoTracer(Tracer):
proxy.meta_data = attr_itr
finally:
self.disable_module_getattr = False
elif kind == 'call_function':
elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method':
elif kind == "call_method":
self.disable_module_getattr = True
try:
if target == '__call__':
if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
**tree_map(unwrap_fn, kwargs))
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
)
finally:
self.disable_module_getattr = False
elif kind == 'call_module':
elif kind == "call_module":
mod = self.root.get_submodule(target)
self.disable_module_getattr = True
try:
@@ -158,11 +163,12 @@ class ColoTracer(Tracer):
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node

def trace(self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:

def trace(
self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None,
) -> Graph:
if meta_args is None:
meta_args = {}

@@ -177,9 +183,7 @@ class ColoTracer(Tracer):
non_concrete_arg_names = sig_names - concrete_arg_names
# update concrete args with default values
for k, v in sig.parameters.items():
if k in sig_names - meta_arg_names and \
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default

def _check_arg_name_valid(names: Iterable[str]):
@@ -194,9 +198,9 @@ class ColoTracer(Tracer):
self.meta_args = meta_args

with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
self.mod_dir = 'self'
self.mod_dir = "self"
self.graph = super().trace(root, concrete_args=concrete_args)
self.mod_dir = ''
self.mod_dir = ""
self.graph.lint()

for node in self.graph.nodes:
@@ -266,17 +270,17 @@ class ColoTracer(Tracer):
# override the torch factory functions to create a proxy when the method
# is called during ``symbolic_trace()``.
def wrap_factory_method(target):

@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
isinstance(p, ColoProxy) for p in kwargs.values())
isinstance(p, ColoProxy) for p in kwargs.values()
)
if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.disable_module_getattr = True
try:
proxy = self.create_proxy('call_function', target, args, kwargs)
proxy = self.create_proxy("call_function", target, args, kwargs)
finally:
self.disable_module_getattr = False
return proxy
@@ -341,10 +345,13 @@ class ColoTracer(Tracer):
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
lambda node: ColoProxy(self, node, n, attr_val))
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (
None
if not self.param_shapes_constant
else lambda node: ColoProxy(self, node, n, attr_val)
)
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
@@ -355,8 +362,9 @@ class ColoTracer(Tracer):
return maybe_buffer_proxy

if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
maybe_parameter_proxy = maybe_get_proxy_for_attr(
attr_val, self.root.named_parameters(), parameter_proxy_cache
)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy