[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -3,12 +3,12 @@ from functools import partial
 import torch
 import torch.distributed as dist
-from torch.types import _bool, _device, _dtype
-from torch.utils._pytree import tree_flatten, tree_map
+from torch.types import _device
+from torch.utils._pytree import tree_map
 from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
-__all__ = ['MetaTensor', 'MetaTensorMode']
+__all__ = ["MetaTensor", "MetaTensorMode"]
 def register_storage(r, data_ptr_fn=None):
@@ -28,8 +28,7 @@ def _normalize_tuple(x):
 # a hack of inplace execution in PyTorch
 def _assert_alias(func):
-    return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen    # TODO: check if should be this aggressive
-                    )
+    return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen)  # TODO: check if should be this aggressive
 class MetaTensor(torch.Tensor):
@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor):
             storage_offset=elem.storage_offset(),
             dtype=elem.dtype,
             layout=elem.layout,
-            device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
-            requires_grad=requires_grad)    # deceive the frontend for aten selections
+            device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
+            requires_grad=requires_grad,
+        )  # deceive the frontend for aten selections
         r._tensor = elem
         # ...the real tensor is held as an element on the tensor.
         if not r._tensor.is_meta:
             val = elem.data_ptr()
             data_ptr_fn = lambda: val
-            r._tensor = r._tensor.to(torch.device('meta'))
+            r._tensor = r._tensor.to(torch.device("meta"))
             # only tensor not on `meta` should be copied to `meta`
         register_storage(r._tensor, data_ptr_fn)
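The hunk above is only a quote-and-paren reformat, but the code it touches is PyTorch's wrapper-subclass idiom: torch.Tensor._make_wrapper_subclass records shape/stride/dtype metadata without allocating storage, while the real payload is parked on the meta device. A minimal, hypothetical sketch of that idiom (not the actual MetaTensor class), assuming a recent PyTorch where _make_wrapper_subclass and .to("meta") exist:

import torch

class TinyMetaTensor(torch.Tensor):
    # Hypothetical sketch of the wrapper-subclass idiom used above.
    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        # no storage is allocated here; only metadata for the frontend
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device if elem.device.type != "meta" else torch.device("cpu"),
            requires_grad=elem.requires_grad,
        )
        # park the payload on `meta` so downstream ops cost no real memory
        r._tensor = elem.to(torch.device("meta"))
        return r

x = TinyMetaTensor(torch.empty(4, 4))
print(x.shape, x._tensor.is_meta)  # torch.Size([4, 4]) True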
@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor):
         return r
     def __repr__(self):
-        name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
+        name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
         if self.grad_fn:
             return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
         return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor):
                 x = x._tensor
             elif isinstance(x, torch.Tensor):
                 device = x.device
-                x = x.to(torch.device('meta'))
+                x = x.to(torch.device("meta"))
             return x
         args = tree_map(unwrap, args)
         kwargs = tree_map(unwrap, kwargs)
-        if 'device' in kwargs:
-            device = kwargs['device']
-            kwargs['device'] = torch.device('meta')
+        if "device" in kwargs:
+            device = kwargs["device"]
+            kwargs["device"] = torch.device("meta")
         # run aten for backend=CPU but actually on backend=Meta
         # here we detect whether or not the execution generates a physical copy
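For context, this hunk restyles the unwrap pass of MetaTensor's dispatch hook: every tensor argument is swapped for its meta payload while the last real device seen (or an explicit device= kwarg) is remembered, so the result can be relabelled afterwards. A standalone, hypothetical sketch of that capture-and-unwrap pattern:

import torch
from torch.utils._pytree import tree_map

def unwrap_to_meta(args, kwargs):
    # Hypothetical helper mirroring the pattern above: move every tensor
    # to `meta`, remembering which device the caller actually asked for.
    device = torch.device("cpu")

    def unwrap(x):
        nonlocal device
        if isinstance(x, torch.Tensor):
            device = x.device
            return x.to(torch.device("meta"))
        return x

    args = tree_map(unwrap, args)
    kwargs = tree_map(unwrap, kwargs)
    if "device" in kwargs:
        device = kwargs["device"]
        kwargs["device"] = torch.device("meta")
    return args, kwargs, device

args, kwargs, device = unwrap_to_meta((torch.ones(2, 2),), {})
print(args[0].is_meta, device)  # True cpu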
@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor):
             nonlocal device
             if isinstance(x, str) or isinstance(x, _device):
                 device = x
-                return torch.device('meta')
+                return torch.device("meta")
             return x
         elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
         return MetaTensor(elem, device=device)
     def cpu(self, *args, **kwargs):
-        if self.device.type == 'cpu':
+        if self.device.type == "cpu":
             return self.to(*args, **kwargs)
-        return self.to(*args, device='cpu', **kwargs)
+        return self.to(*args, device="cpu", **kwargs)
     def cuda(self, device=None, non_blocking=False):
         if device is not None:
             return self.to(device=device, non_blocking=non_blocking)
-        return self.to(device='cuda:0', non_blocking=non_blocking)
+        return self.to(device="cuda:0", non_blocking=non_blocking)
     def data_ptr(self):
         return self._tensor.data_ptr()
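These overrides keep device movement purely symbolic: .cuda() never touches a GPU, it relabels the wrapper while the payload stays on meta. Assuming the constructor shown in this diff and an import path along these lines, usage would look roughly like:

import torch
from colossalai._analyzer._subclasses import MetaTensor  # import path assumed

t = MetaTensor(torch.empty(1024, 1024, device="meta"))
g = t.cuda()     # relabelled as cuda:0; no CUDA memory is allocated
print(g.device)  # cuda:0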
@@ -177,19 +177,17 @@ class MetaTensorMode(object):
     """
     def __init__(self):
-        self.torch_overrides = {}    # override torch.xxx
-        self.dist_overrides = {}    # override torch.distributed.xxx
+        self.torch_overrides = {}  # override torch.xxx
+        self.dist_overrides = {}  # override torch.distributed.xxx
     def __enter__(self):
         def _dummy(*args, **kwargs):
             pass
         def _new(*args, orig_new=torch.empty, **kwargs):
-            return MetaTensor(orig_new(*args, **{
-                **kwargs, 'device': 'meta'
-            }),
-                              device=kwargs.get('device', torch.device('cpu')))
+            return MetaTensor(
+                orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
+            )
         for func in _TorchOverrideableFactoryMethod:
            self.torch_overrides[func] = getattr(torch, func)
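__enter__ goes on to replace each overrideable factory (torch.empty and friends) with _new, which allocates on meta but preserves the device the caller asked for. A self-contained, hypothetical sketch of that patch-and-restore pattern, using plain meta tensors instead of MetaTensor:

import torch

class TinyMetaMode:
    # Hypothetical sketch of the factory-patching pattern in __enter__ above.
    FACTORIES = ["empty", "zeros", "ones"]  # illustrative subset

    def __enter__(self):
        self._saved = {name: getattr(torch, name) for name in self.FACTORIES}
        for name, orig in self._saved.items():
            def _new(*args, orig_new=orig, **kwargs):
                # allocate on `meta` no matter which device was requested
                return orig_new(*args, **{**kwargs, "device": "meta"})
            setattr(torch, name, _new)
        return self

    def __exit__(self, *exc):
        # restore the original factories on exit
        for name, orig in self._saved.items():
            setattr(torch, name, orig)

with TinyMetaMode():
    x = torch.zeros(8, 8, device="cuda")  # no GPU is touched
print(x.is_meta)  # True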