[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
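For context on the bullet points above: excluding CUDA sources from clang-format is typically done with an exclude pattern on the clang-format hook in .pre-commit-config.yaml, and re-running every hook over the whole tree is done with pre-commit run --all-files. The sketch below is only illustrative; the hook repositories, pinned revisions, and the exact exclude regex are assumptions for illustration and are not taken from this commit.

# .pre-commit-config.yaml -- illustrative sketch only, not the file changed in this commit
repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6                 # assumed pin; use whatever revision the project tracks
    hooks:
      - id: clang-format
        # skip CUDA sources so clang-format leaves *.cu / *.cuh files untouched
        exclude: \.(cu|cuh)$
  - repo: https://github.com/psf/black
    rev: 23.9.1                  # assumed pin
    hooks:
      - id: black

# Reformat the entire repository once the hooks are updated:
#   pre-commit run --all-files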


@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Union
import torch
from packaging import version
@@ -24,25 +24,23 @@ orig_empty_like = torch.empty_like
def new(*args, **kwargs):
return orig_empty(*args, **kwargs, device=torch.device('meta'))
return orig_empty(*args, **kwargs, device=torch.device("meta"))
def new_strided(*args, **kwargs):
return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
return orig_empty_strided(*args, **kwargs, device=torch.device("meta"))
def new_like(*args, **kwargs):
return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
return orig_empty_like(*args, **kwargs, device=torch.device("meta"))
def register_meta(op, register_dispatcher=True):
def wrapper(f):
def add_func(op):
meta_table[op] = f
if register_dispatcher:
name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
@@ -54,7 +52,7 @@ def register_meta(op, register_dispatcher=True):
return wrapper
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
@@ -69,7 +67,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
output_padding: List[int],
groups: int,
):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
@@ -146,7 +143,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
kernel_size[i],
stride[i],
output_padding_list[i],
))
)
)
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
@@ -180,19 +178,39 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
*extra_args):
def meta__conv(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
*extra_args,
):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
padding, dilation, transposed, output_padding, groups, output_mask):
def meta_conv_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
return new_like(input), new_like(weight), new((bias_sizes))
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -224,7 +242,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
batch_sizes,
dropout_state,
):
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
@@ -240,8 +257,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
out_shape = ([mini_batch, seq_length, out_size *
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
out_shape = (
[mini_batch, seq_length, out_size * num_directions]
if batch_first
else [seq_length, mini_batch, out_size * num_directions]
)
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -257,15 +277,21 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
def meta_cudnn_rnn_backward(input: torch.Tensor,
weight: torch.Tensor,
weight_stride0: int,
hx: torch.Tensor,
cx: Optional[torch.Tensor] = None,
*args,
**kwargs):
return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
()) # (grad_input, grad_weight, grad_hx, grad_cx)
def meta_cudnn_rnn_backward(
input: torch.Tensor,
weight: torch.Tensor,
weight_stride0: int,
hx: torch.Tensor,
cx: Optional[torch.Tensor] = None,
*args,
**kwargs,
):
return (
new_like(input),
new_like(weight),
new_like(hx),
new_like(cx) if cx is not None else new(()),
) # (grad_input, grad_weight, grad_hx, grad_cx)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
@@ -278,7 +304,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.hardtanh_backward.default,
]
if version.parse(torch.__version__) < version.parse('2.0.0'):
if version.parse(torch.__version__) < version.parse("2.0.0"):
_unregistered_ewise += [
aten.prelu_backward.default,
]
@@ -296,37 +322,61 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
save_mean, save_invstd, train, eps, output_mask):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
def meta_bn_backward(
dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
train,
eps,
output_mask,
):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
return new_like(input), new((n_input)), new((n_input)), new(
(0), dtype=torch.uint8) # (output, running_mean, running_var, reserve)
return (
new_like(input),
new((n_input)),
new((n_input)),
new((0), dtype=torch.uint8),
) # (output, running_mean, running_var, reserve)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
save_mean, save_invstd, eps, reserve):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
def meta_cudnn_bn_backward(
dY: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
running_mean,
running_var,
save_mean,
save_invstd,
eps,
reserve,
):
return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs, n_input = input.size(0), input.size(1)
return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
grad_input_mask):
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
def meta_ln_backward(
dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
# ================================== Misc ==========================================
# Maybe incorrect
@@ -355,8 +405,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
def meta_embedding_dense_backward(
grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
):
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
# ============================== Dropout ===========================================
@@ -364,14 +415,14 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
# notice that mask is bool
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
return new_like(grad) # (grad_in)
return new_like(grad) # (grad_in)
if version.parse(torch.__version__) < version.parse('1.13.0'):
if version.parse(torch.__version__) < version.parse("1.13.0"):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.eye.m_out)
def meta_eye(n: int, m: int, out: torch.Tensor):
@@ -385,24 +436,28 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\
"tensors used as indices must be long, byte or bool tensors"
assert index.dtype in [
torch.long,
torch.int8,
torch.bool,
], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
assert index.shape[j] == self.shape[
k +
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
assert (
index.shape[j] == self.shape[k + j]
), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
assert len(
indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
assert (
len(indices) <= self.ndim
), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs


@@ -1,5 +1,4 @@
import torch
import torch.distributed as dist
from packaging import version
__all__ = [
@@ -48,7 +47,7 @@ _DistCommMethod = [
"scatter",
]
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
aten = torch.ops.aten
# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp


@@ -8,7 +8,7 @@ from contextlib import contextmanager
from enum import Enum, auto
from functools import partial, reduce
from numbers import Number
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Union
import torch
from packaging import version
@@ -36,15 +36,15 @@ def _format_flops(flop):
B = 1e9
T = 1e12
if flop < K:
return f'{flop:.2f}'
return f"{flop:.2f}"
elif flop < M:
return f'{flop / K:.2f}K'
return f"{flop / K:.2f}K"
elif flop < B:
return f'{flop / M:.2f}M'
return f"{flop / M:.2f}M"
elif flop < T:
return f'{flop / B:.2f}B'
return f"{flop / B:.2f}B"
else:
return f'{flop / T:.2f}T'
return f"{flop / T:.2f}T"
def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
Returns:
Number: The total number of floating point operations (FWD + BWD).
"""
maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
maybe_inplace = (
getattr(module, "inplace", False)
or kwargs.get("inplace", False)
or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_")
)
class DummyModule(torch.nn.Module):
def __init__(self, func):
super().__init__()
self.func = func
@@ -74,21 +76,20 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
flop_counts = defaultdict(lambda: defaultdict(int))
parents = ['Global']
parents = ["Global"]
module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
class FlopTensor(MetaTensor):
_tensor: torch.Tensor
def __repr__(self):
name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# no_dispatch is only needed if you use enable_python_mode.
# It prevents infinite recursion.
rs = super().__torch_dispatch__(func, types, args, kwargs)
@@ -115,9 +116,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return isinstance(x, torch.Tensor) and x.is_floating_point()
def create_backwards_push(name):
class PushState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -134,9 +133,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return PushState.apply
def create_backwards_pop(name):
class PopState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -147,14 +144,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
@staticmethod
def backward(ctx, *grad_outs):
nonlocal parents
assert (parents[-1] == name)
assert parents[-1] == name
parents.pop()
return grad_outs
return PopState.apply
def enter_module(name):
def f(module, inputs):
nonlocal parents
parents.append(name)
@@ -165,10 +161,9 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return f
def exit_module(name):
def f(module, inputs, outputs):
nonlocal parents
assert (parents[-1] == name)
assert parents[-1] == name
parents.pop()
outputs = normalize_tuple(outputs)
return create_backwards_push(name)(*outputs)
@@ -189,7 +184,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
for mod in flop_counts.keys():
print(f"Module: ", mod)
for k, v in flop_counts[mod].items():
print('\t', k, _format_flops(v))
print("\t", k, _format_flops(v))
print()
def detach_variables(r):
@@ -201,7 +196,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
def wrap(r):
if isinstance(r, torch.Tensor):
data_ptr_fn = getattr(r, '_tensor', r).data_ptr
data_ptr_fn = getattr(r, "_tensor", r).data_ptr
r = FlopTensor(detach_variables(r))
if maybe_inplace:
r = r + 0
@@ -375,8 +370,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
'shape') else inputs[affine_arg_index]
has_affine = (
inputs[affine_arg_index].shape is not None
if hasattr(inputs[affine_arg_index], "shape")
else inputs[affine_arg_index]
)
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
@@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
@@ -420,33 +418,30 @@ def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Calla
def zero_flop_jit(*args):
"""
Count flops for zero flop layers.
Count flops for zero flop layers.
"""
return 0
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
flop_mapping = {
# gemm
# gemm
aten.mm.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
# convolution
# convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
# normalization
# normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
# pooling
# pooling
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
@@ -469,7 +464,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
}
ewise_flop_aten = [
# basic op
# basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
@@ -485,8 +480,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
# activation op
# activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
@@ -509,15 +503,12 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
# dropout
# dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
# distribution
# distribution
aten.bernoulli_.float,
# where
# where
aten.where.self,
]
for op in ewise_flop_aten:


@@ -3,12 +3,12 @@ from functools import partial
import torch
import torch.distributed as dist
from torch.types import _bool, _device, _dtype
from torch.utils._pytree import tree_flatten, tree_map
from torch.types import _device
from torch.utils._pytree import tree_map
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
__all__ = ['MetaTensor', 'MetaTensorMode']
__all__ = ["MetaTensor", "MetaTensorMode"]
def register_storage(r, data_ptr_fn=None):
@@ -28,8 +28,7 @@ def _normalize_tuple(x):
# a hack of inplace execution in PyTorch
def _assert_alias(func):
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive
)
return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen) # TODO: check if should be this aggressive
class MetaTensor(torch.Tensor):
@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
requires_grad=requires_grad) # deceive the frontend for aten selections
device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
requires_grad=requires_grad,
) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
val = elem.data_ptr()
data_ptr_fn = lambda: val
r._tensor = r._tensor.to(torch.device('meta'))
r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta`
register_storage(r._tensor, data_ptr_fn)
@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor):
return r
def __repr__(self):
name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor):
x = x._tensor
elif isinstance(x, torch.Tensor):
device = x.device
x = x.to(torch.device('meta'))
x = x.to(torch.device("meta"))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
if 'device' in kwargs:
device = kwargs['device']
kwargs['device'] = torch.device('meta')
if "device" in kwargs:
device = kwargs["device"]
kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta
# here we detect whether or not the execution generates a physical copy
@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor):
nonlocal device
if isinstance(x, str) or isinstance(x, _device):
device = x
return torch.device('meta')
return torch.device("meta")
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, device=device)
def cpu(self, *args, **kwargs):
if self.device.type == 'cpu':
if self.device.type == "cpu":
return self.to(*args, **kwargs)
return self.to(*args, device='cpu', **kwargs)
return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
return self.to(device='cuda:0', non_blocking=non_blocking)
return self.to(device="cuda:0", non_blocking=non_blocking)
def data_ptr(self):
return self._tensor.data_ptr()
@@ -177,19 +177,17 @@ class MetaTensorMode(object):
"""
def __init__(self):
self.torch_overrides = {} # override torch.xxx
self.dist_overrides = {} # override torch.distributed.xxx
self.torch_overrides = {} # override torch.xxx
self.dist_overrides = {} # override torch.distributed.xxx
def __enter__(self):
def _dummy(*args, **kwargs):
pass
def _new(*args, orig_new=torch.empty, **kwargs):
return MetaTensor(orig_new(*args, **{
**kwargs, 'device': 'meta'
}),
device=kwargs.get('device', torch.device('cpu')))
return MetaTensor(
orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
)
for func in _TorchOverrideableFactoryMethod:
self.torch_overrides[func] = getattr(torch, func)