Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -3,7 +3,7 @@
 # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 # for more meta_registrations

-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Union

 import torch
 from packaging import version
@@ -24,25 +24,23 @@ orig_empty_like = torch.empty_like


 def new(*args, **kwargs):
-    return orig_empty(*args, **kwargs, device=torch.device('meta'))
+    return orig_empty(*args, **kwargs, device=torch.device("meta"))


 def new_strided(*args, **kwargs):
-    return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
+    return orig_empty_strided(*args, **kwargs, device=torch.device("meta"))


 def new_like(*args, **kwargs):
-    return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
+    return orig_empty_like(*args, **kwargs, device=torch.device("meta"))


 def register_meta(op, register_dispatcher=True):
-
     def wrapper(f):
-
         def add_func(op):
             meta_table[op] = f
             if register_dispatcher:
-                name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+                name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
                 try:
                     meta_lib.impl(name, f)
                 except:
@@ -54,7 +52,7 @@ def register_meta(op, register_dispatcher=True):
     return wrapper


-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
     # ============================== Convolutions ======================================
     # https://github.com/pytorch/pytorch/pull/79834
     @register_meta(aten.convolution.default)
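For context on the hunks above: register_meta and the new*/new_like helpers build a table of shape-only ("meta") kernels that later hunks register for individual aten ops. The following is a minimal, self-contained sketch of that pattern with a deliberately simplified register_meta; aten.relu.default is chosen only as an example op and is not part of this change.

import torch

aten = torch.ops.aten
meta_table = {}


def register_meta(op):
    # simplified: the version in this file can also register into torch's meta dispatch table
    def wrapper(f):
        meta_table[op] = f
        return f

    return wrapper


@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor):
    # shape/dtype-only result on the "meta" device; no real memory is allocated
    return torch.empty_like(input, device="meta")


x = torch.empty(4, 8, device="meta")
print(meta_table[aten.relu.default](x).shape)  # torch.Size([4, 8])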
@@ -69,7 +67,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         output_padding: List[int],
         groups: int,
     ):
-
         def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
             """
             Formula to apply to calculate the length of some dimension of the output
@@ -146,7 +143,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
                             kernel_size[i],
                             stride[i],
                             output_padding_list[i],
-                        ))
+                        )
+                    )
                 else:
                     ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
             return ret_shape
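The _formula helper referenced above computes the length of one output dimension of a (non-transposed) convolution. Its body is not shown in this hunk; the sketch below is the standard convolution length formula, given for reference only and not copied from the file.

def conv_output_length(ln: int, p: int, d: int, k: int, s: int) -> int:
    # ln: input length, p: padding, d: dilation, k: kernel size, s: stride
    return (ln + 2 * p - d * (k - 1) - 1) // s + 1


assert conv_output_length(32, 1, 1, 3, 1) == 32  # padded 3-wide kernel, stride 1: length preserved
assert conv_output_length(32, 0, 1, 3, 2) == 15  # stride 2 roughly halves the length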
@@ -180,19 +178,39 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
         out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
         mem_fmt = pick_memory_format()
-        out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
+        out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
         return out

     @register_meta(aten._convolution.default)
-    def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
-                   padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
-                   *extra_args):
+    def meta__conv(
+        input_tensor: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        stride: List[int],
+        padding: List[int],
+        dilation: List[int],
+        is_transposed: bool,
+        output_padding: List[int],
+        groups: int,
+        *extra_args,
+    ):
         out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
         return out

     @register_meta(aten.convolution_backward.default)
-    def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
-                           padding, dilation, transposed, output_padding, groups, output_mask):
+    def meta_conv_backward(
+        grad_output: torch.Tensor,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias_sizes,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        output_padding,
+        groups,
+        output_mask,
+    ):
         return new_like(input), new_like(weight), new((bias_sizes))

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -224,7 +242,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         batch_sizes,
         dropout_state,
     ):
-
         is_input_packed = len(batch_sizes) != 0
         if is_input_packed:
             seq_length = len(batch_sizes)
@@ -240,8 +257,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         if is_input_packed:
             out_shape = [batch_sizes_sum, out_size * num_directions]
         else:
-            out_shape = ([mini_batch, seq_length, out_size *
-                          num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+            out_shape = (
+                [mini_batch, seq_length, out_size * num_directions]
+                if batch_first
+                else [seq_length, mini_batch, out_size * num_directions]
+            )
         output = input.new_empty(out_shape)

         cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -257,15 +277,21 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
     @register_meta(aten._cudnn_rnn_backward.default)
-    def meta_cudnn_rnn_backward(input: torch.Tensor,
-                                weight: torch.Tensor,
-                                weight_stride0: int,
-                                hx: torch.Tensor,
-                                cx: Optional[torch.Tensor] = None,
-                                *args,
-                                **kwargs):
-        return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
-            ())    # (grad_input, grad_weight, grad_hx, grad_cx)
+    def meta_cudnn_rnn_backward(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_stride0: int,
+        hx: torch.Tensor,
+        cx: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ):
+        return (
+            new_like(input),
+            new_like(weight),
+            new_like(hx),
+            new_like(cx) if cx is not None else new(()),
+        )  # (grad_input, grad_weight, grad_hx, grad_cx)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
     # ============================== Activations =======================================
@@ -278,7 +304,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.hardtanh_backward.default,
     ]

-    if version.parse(torch.__version__) < version.parse('2.0.0'):
+    if version.parse(torch.__version__) < version.parse("2.0.0"):
         _unregistered_ewise += [
             aten.prelu_backward.default,
         ]
@@ -296,37 +322,61 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
     @register_meta(aten.native_batch_norm_backward.default)
-    def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
-                         save_mean, save_invstd, train, eps, output_mask):
-        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)
+    def meta_bn_backward(
+        dY: torch.Tensor,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        running_mean,
+        running_var,
+        save_mean,
+        save_invstd,
+        train,
+        eps,
+        output_mask,
+    ):
+        return new_like(input), new_like(weight), new_like(weight)  # (dX, dgamma, dbeta)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
     @register_meta(aten.cudnn_batch_norm.default)
     def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
         n_input = input.size(1)
-        return new_like(input), new((n_input)), new((n_input)), new(
-            (0), dtype=torch.uint8)    # (output, running_mean, running_var, reserve)
+        return (
+            new_like(input),
+            new((n_input)),
+            new((n_input)),
+            new((0), dtype=torch.uint8),
+        )  # (output, running_mean, running_var, reserve)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
     # NB: CuDNN only implements the backward algorithm for batchnorm
     # in training mode (evaluation mode batchnorm has a different algorithm),
     # which is why this doesn't accept a 'training' parameter.
     @register_meta(aten.cudnn_batch_norm_backward.default)
-    def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
-                               save_mean, save_invstd, eps, reserve):
-        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)
+    def meta_cudnn_bn_backward(
+        dY: torch.Tensor,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        running_mean,
+        running_var,
+        save_mean,
+        save_invstd,
+        eps,
+        reserve,
+    ):
+        return new_like(input), new_like(weight), new_like(weight)  # (dX, dgamma, dbeta)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
     @register_meta(aten.native_layer_norm.default)
     def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
         bs, n_input = input.size(0), input.size(1)
-        return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))    # (output, running_mean, running_var)
+        return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))  # (output, running_mean, running_var)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
     @register_meta(aten.native_layer_norm_backward.default)
-    def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
-                         grad_input_mask):
-        return new_like(input), new_like(weight), new_like(bias)    # (dX, dgamma, dbeta)
+    def meta_ln_backward(
+        dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
+    ):
+        return new_like(input), new_like(weight), new_like(bias)  # (dX, dgamma, dbeta)

     # ================================== Misc ==========================================
     # Maybe incorrect
@@ -355,8 +405,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp

     @register_meta(aten.embedding_dense_backward.default)
-    def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
-                                      scale_grad_by_freq):
+    def meta_embedding_dense_backward(
+        grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
+    ):
         return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)

     # ============================== Dropout ===========================================
@@ -364,14 +415,14 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     @register_meta(aten.native_dropout.default)
     def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
         # notice that mask is bool
-        return new_like(input), new_like(input, dtype=torch.bool)    # (output, mask)
+        return new_like(input), new_like(input, dtype=torch.bool)  # (output, mask)

     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
     @register_meta(aten.native_dropout_backward.default)
     def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
-        return new_like(grad)    # (grad_in)
+        return new_like(grad)  # (grad_in)

-    if version.parse(torch.__version__) < version.parse('1.13.0'):
+    if version.parse(torch.__version__) < version.parse("1.13.0"):
         # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
         @register_meta(aten.eye.m_out)
         def meta_eye(n: int, m: int, out: torch.Tensor):
@@ -385,24 +436,28 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
             result: List[Optional[torch.Tensor]] = []
             for i, index in enumerate(indices):
                 if index is not None:
-                    assert index.dtype in [torch.long, torch.int8, torch.bool],\
-                        "tensors used as indices must be long, byte or bool tensors"
+                    assert index.dtype in [
+                        torch.long,
+                        torch.int8,
+                        torch.bool,
+                    ], "tensors used as indices must be long, byte or bool tensors"
                     if index.dtype in [torch.int8, torch.bool]:
                         nonzero = index.nonzero()
                         k = len(result)
                         assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                         for j in range(index.ndim):
-                            assert index.shape[j] == self.shape[
-                                k +
-                                j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                            assert (
+                                index.shape[j] == self.shape[k + j]
+                            ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                             result.append(nonzero.select(1, j))
                     else:
                         result.append(index)
                 else:
                     result.append(index)
             indices = result
-            assert len(
-                indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+            assert (
+                len(indices) <= self.ndim
+            ), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
             # expand_outplace
             import torch._refs as refs

@@ -1,5 +1,4 @@
 import torch
 import torch.distributed as dist
 from packaging import version

 __all__ = [
@@ -48,7 +47,7 @@ _DistCommMethod = [
     "scatter",
 ]

-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
     aten = torch.ops.aten
     # TODO: dive deep here
     # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
@@ -8,7 +8,7 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from functools import partial, reduce
 from numbers import Number
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Union

 import torch
 from packaging import version
@@ -36,15 +36,15 @@ def _format_flops(flop):
     B = 1e9
     T = 1e12
     if flop < K:
-        return f'{flop:.2f}'
+        return f"{flop:.2f}"
     elif flop < M:
-        return f'{flop / K:.2f}K'
+        return f"{flop / K:.2f}K"
     elif flop < B:
-        return f'{flop / M:.2f}M'
+        return f"{flop / M:.2f}M"
     elif flop < T:
-        return f'{flop / B:.2f}B'
+        return f"{flop / B:.2f}B"
     else:
-        return f'{flop / T:.2f}T'
+        return f"{flop / T:.2f}T"


 def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
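The flop_count signature above implies a direct call on a module (or callable) plus its inputs. A hedged usage sketch follows; the model, input, and import path are illustrative assumptions, not part of this change.

import torch
import torch.nn as nn

# The import path below is an assumption for illustration; flop_count is the
# function defined in this file and may live under a different module path.
from colossalai._analyzer._subclasses.flop_tensor import flop_count

model = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 10))
x = torch.randn(4, 128)

total = flop_count(model, x, verbose=True)  # verbose=True prints a per-module breakdown
print(total)  # total floating point operations (FWD + BWD), per the docstring above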
@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
     Returns:
         Number: The total number of floating point operations (FWD + BWD).
     """
-    maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
-                     or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
+    maybe_inplace = (
+        getattr(module, "inplace", False)
+        or kwargs.get("inplace", False)
+        or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_")
+    )

     class DummyModule(torch.nn.Module):
-
         def __init__(self, func):
             super().__init__()
             self.func = func
@@ -74,21 +76,20 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:

     total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
     flop_counts = defaultdict(lambda: defaultdict(int))
-    parents = ['Global']
+    parents = ["Global"]
     module = module if isinstance(module, torch.nn.Module) else DummyModule(module)

     class FlopTensor(MetaTensor):
         _tensor: torch.Tensor

         def __repr__(self):
-            name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
+            name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor"
             if self.grad_fn:
                 return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
             return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
             # no_dispatch is only needed if you use enable_python_mode.
             # It prevents infinite recursion.
             rs = super().__torch_dispatch__(func, types, args, kwargs)
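FlopTensor.__torch_dispatch__ above intercepts every aten call made during the traced forward/backward and charges a cost to the current parent scope. The sketch below shows the same interception idea in a self-contained way using TorchDispatchMode and a placeholder cost; it is not the file's FlopTensor/flop_mapping machinery, just an illustration of the dispatch mechanics.

from collections import defaultdict

import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpCounter(TorchDispatchMode):
    # Run the real aten op, then charge it to a table. The file above keys the
    # charge by module scope and uses a flop_mapping cost table; here the cost
    # is a placeholder (one hit per call).
    def __init__(self):
        super().__init__()
        self.counts = defaultdict(int)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        self.counts[func] += 1
        return out


counter = OpCounter()
with counter:
    a = torch.randn(8, 8)
    b = a @ a

for op, n in counter.counts.items():
    print(op, n)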
@@ -115,9 +116,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
         return isinstance(x, torch.Tensor) and x.is_floating_point()

     def create_backwards_push(name):
-
         class PushState(torch.autograd.Function):
-
             @staticmethod
             def forward(ctx, *args):
                 args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -134,9 +133,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
         return PushState.apply

     def create_backwards_pop(name):
-
         class PopState(torch.autograd.Function):
-
             @staticmethod
             def forward(ctx, *args):
                 args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -147,14 +144,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
             @staticmethod
             def backward(ctx, *grad_outs):
                 nonlocal parents
-                assert (parents[-1] == name)
+                assert parents[-1] == name
                 parents.pop()
                 return grad_outs

         return PopState.apply

     def enter_module(name):
-
         def f(module, inputs):
             nonlocal parents
             parents.append(name)
@@ -165,10 +161,9 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
         return f

     def exit_module(name):
-
         def f(module, inputs, outputs):
             nonlocal parents
-            assert (parents[-1] == name)
+            assert parents[-1] == name
             parents.pop()
             outputs = normalize_tuple(outputs)
             return create_backwards_push(name)(*outputs)
@@ -189,7 +184,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
         for mod in flop_counts.keys():
             print(f"Module: ", mod)
             for k, v in flop_counts[mod].items():
-                print('\t', k, _format_flops(v))
+                print("\t", k, _format_flops(v))
             print()

     def detach_variables(r):
@@ -201,7 +196,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:

     def wrap(r):
         if isinstance(r, torch.Tensor):
-            data_ptr_fn = getattr(r, '_tensor', r).data_ptr
+            data_ptr_fn = getattr(r, "_tensor", r).data_ptr
             r = FlopTensor(detach_variables(r))
             if maybe_inplace:
                 r = r + 0
@@ -375,8 +370,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
         # Inputs[0] contains the shape of the input.
         input_shape = inputs[input_arg_index].shape

-        has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
-                                                                           'shape') else inputs[affine_arg_index]
+        has_affine = (
+            inputs[affine_arg_index].shape is not None
+            if hasattr(inputs[affine_arg_index], "shape")
+            else inputs[affine_arg_index]
+        )
         assert 2 <= len(input_shape) <= 5, input_shape
         # 5 is just a rough estimate
         flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
@@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
         training = inputs[-3]
         assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
     if training:
-        return norm_flop_counter(1, 0)(inputs, outputs)    # pyre-ignore
+        return norm_flop_counter(1, 0)(inputs, outputs)  # pyre-ignore
     has_affine = inputs[1].shape is not None
     input_shape = reduce(operator.mul, inputs[0].shape)
     return input_shape * (2 if has_affine else 1)
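A quick arithmetic check of the estimates above, using an arbitrary shape that is not taken from this diff:

import operator
from functools import reduce

# norm_flop_counter-style estimate: ~5 ops per element with an affine transform
# (normalize, then scale and shift), ~4 without; the batchnorm inference path
# charges 2 per element with affine, 1 without.
input_shape = (32, 64, 56, 56)
elements = reduce(operator.mul, input_shape)  # 6_422_528

print(elements * 5)  # 32_112_640 flops: training-mode norm with affine
print(elements * 4)  # 25_690_112 flops: training-mode norm without affine
print(elements * 2)  # 12_845_056 flops: batchnorm inference with affine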
@@ -420,33 +418,30 @@ def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Calla

 def zero_flop_jit(*args):
     """
-        Count flops for zero flop layers.
+    Count flops for zero flop layers.
     """
     return 0


-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
     flop_mapping = {
-    # gemm
+        # gemm
         aten.mm.default: matmul_flop_jit,
         aten.matmul.default: matmul_flop_jit,
         aten.addmm.default: addmm_flop_jit,
         aten.bmm.default: bmm_flop_jit,
-
-    # convolution
+        # convolution
         aten.convolution.default: conv_flop_jit,
         aten._convolution.default: conv_flop_jit,
         aten.convolution_backward.default: conv_backward_flop_jit,
-
-    # normalization
+        # normalization
         aten.native_batch_norm.default: batchnorm_flop_jit,
         aten.native_batch_norm_backward.default: batchnorm_flop_jit,
         aten.cudnn_batch_norm.default: batchnorm_flop_jit,
         aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
         aten.native_layer_norm.default: norm_flop_counter(2, 0),
         aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
-
-    # pooling
+        # pooling
         aten.avg_pool1d.default: ewise_flop_counter(1, 0),
         aten.avg_pool2d.default: ewise_flop_counter(1, 0),
         aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
@@ -469,7 +464,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     }

     ewise_flop_aten = [
-    # basic op
+        # basic op
         aten.add.Tensor,
         aten.add_.Tensor,
         aten.div.Tensor,
@@ -485,8 +480,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.sum.default,
         aten.sum.dim_IntList,
         aten.mean.dim,
-
-    # activation op
+        # activation op
         aten.hardswish.default,
         aten.hardswish_.default,
         aten.hardswish_backward.default,
@@ -509,15 +503,12 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.tanh.default,
         aten.tanh_backward.default,
         aten.threshold_backward.default,
-
-    # dropout
+        # dropout
         aten.native_dropout.default,
         aten.native_dropout_backward.default,
-
-    # distribution
+        # distribution
         aten.bernoulli_.float,
-
-    # where
+        # where
         aten.where.self,
     ]
     for op in ewise_flop_aten:
@@ -3,12 +3,12 @@ from functools import partial

 import torch
 import torch.distributed as dist
-from torch.types import _bool, _device, _dtype
-from torch.utils._pytree import tree_flatten, tree_map
+from torch.types import _device
+from torch.utils._pytree import tree_map

 from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod

-__all__ = ['MetaTensor', 'MetaTensorMode']
+__all__ = ["MetaTensor", "MetaTensorMode"]


 def register_storage(r, data_ptr_fn=None):
@@ -28,8 +28,7 @@ def _normalize_tuple(x):

 # a hack of inplace execution in PyTorch
 def _assert_alias(func):
-    return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen    # TODO: check if should be this aggressive
-                    )
+    return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen)  # TODO: check if should be this aggressive


 class MetaTensor(torch.Tensor):
@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor):
             storage_offset=elem.storage_offset(),
             dtype=elem.dtype,
             layout=elem.layout,
-            device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
-            requires_grad=requires_grad)    # deceive the frontend for aten selections
+            device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
+            requires_grad=requires_grad,
+        )  # deceive the frontend for aten selections
         r._tensor = elem
         # ...the real tensor is held as an element on the tensor.
         if not r._tensor.is_meta:
             val = elem.data_ptr()
             data_ptr_fn = lambda: val
-            r._tensor = r._tensor.to(torch.device('meta'))
+            r._tensor = r._tensor.to(torch.device("meta"))

         # only tensor not on `meta` should be copied to `meta`
         register_storage(r._tensor, data_ptr_fn)
@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor):
         return r

     def __repr__(self):
-        name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
+        name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
         if self.grad_fn:
             return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
         return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor):
                 x = x._tensor
             elif isinstance(x, torch.Tensor):
                 device = x.device
-                x = x.to(torch.device('meta'))
+                x = x.to(torch.device("meta"))
             return x

         args = tree_map(unwrap, args)
         kwargs = tree_map(unwrap, kwargs)

-        if 'device' in kwargs:
-            device = kwargs['device']
-            kwargs['device'] = torch.device('meta')
+        if "device" in kwargs:
+            device = kwargs["device"]
+            kwargs["device"] = torch.device("meta")

         # run aten for backend=CPU but actually on backend=Meta
         # here we detect whether or not the execution generates a physical copy
@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor):
             nonlocal device
             if isinstance(x, str) or isinstance(x, _device):
                 device = x
-                return torch.device('meta')
+                return torch.device("meta")
             return x

         elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
         return MetaTensor(elem, device=device)

     def cpu(self, *args, **kwargs):
-        if self.device.type == 'cpu':
+        if self.device.type == "cpu":
             return self.to(*args, **kwargs)
-        return self.to(*args, device='cpu', **kwargs)
+        return self.to(*args, device="cpu", **kwargs)

     def cuda(self, device=None, non_blocking=False):
         if device is not None:
             return self.to(device=device, non_blocking=non_blocking)
-        return self.to(device='cuda:0', non_blocking=non_blocking)
+        return self.to(device="cuda:0", non_blocking=non_blocking)

     def data_ptr(self):
         return self._tensor.data_ptr()
@@ -177,19 +177,17 @@ class MetaTensorMode(object):
     """

     def __init__(self):
-        self.torch_overrides = {}    # override torch.xxx
-        self.dist_overrides = {}    # override torch.distributed.xxx
+        self.torch_overrides = {}  # override torch.xxx
+        self.dist_overrides = {}  # override torch.distributed.xxx

     def __enter__(self):
-
         def _dummy(*args, **kwargs):
             pass

         def _new(*args, orig_new=torch.empty, **kwargs):
-            return MetaTensor(orig_new(*args, **{
-                **kwargs, 'device': 'meta'
-            }),
-                              device=kwargs.get('device', torch.device('cpu')))
+            return MetaTensor(
+                orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
+            )

         for func in _TorchOverrideableFactoryMethod:
             self.torch_overrides[func] = getattr(torch, func)