[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -39,7 +39,7 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
         _tensor: torch.Tensor
         _node: Node
 
-        __slots__ = ['_tensor', '_node']
+        __slots__ = ["_tensor", "_node"]
 
         @staticmethod
         def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
@@ -51,22 +51,22 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
                 dtype=tensor.dtype,
                 layout=tensor.layout,
                 device=fake_device if fake_device is not None else tensor.device,
-                requires_grad=tensor.requires_grad)  # deceive the frontend for aten selections
+                requires_grad=tensor.requires_grad,
+            )  # deceive the frontend for aten selections
             r._tensor = tensor
             if placeholder:
                 if name is None:
-                    name = 'input'
-                r._node = graph.create_node('placeholder',
-                                            'placeholder', (graph._root,),
-                                            name=namespace.create_name(name, tensor))
+                    name = "input"
+                r._node = graph.create_node(
+                    "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor)
+                )
             # ...the real tensor is held as an element on the tensor.
             if not r._tensor.is_meta:
-                r._tensor = r._tensor.to(torch.device('meta'))
+                r._tensor = r._tensor.to(torch.device("meta"))
             return r
 
         @classmethod
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             def unwrap(x):
                 nonlocal fake_device
                 if isinstance(x, MetaProxy):
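Aside: both sides of the hunk above record a "placeholder" node with graph.create_node; only the formatting changes. For reference, a minimal standalone sketch of the torch.fx Graph API being used (the names g, x, y are illustrative, not from this file):

import torch
from torch.fx import Graph

# Hand-build a tiny FX graph: one placeholder feeding one call_function node.
g = Graph()
x = g.create_node("placeholder", "x")  # positional form: op, target
y = g.create_node("call_function", torch.relu, (x,), {})
g.output(y)
print(g)  # prints the recorded nodes in IR form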
@@ -75,21 +75,21 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
                     # assert not isinstance(x, MetaProxy)
                 elif isinstance(x, torch.Tensor):
                     fake_device = x.device
-                    x = x.to(torch.device('meta'))
+                    x = x.to(torch.device("meta"))
                 return x
 
             def get_node(x):
-                if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
-                    x = MetaProxy(x, placeholder=True, name='weight')
-                return x if not hasattr(x, '_node') else x._node
+                if isinstance(x, torch.Tensor) and not hasattr(x, "_node"):
+                    x = MetaProxy(x, placeholder=True, name="weight")
+                return x if not hasattr(x, "_node") else x._node
 
             args_node = tree_map(get_node, args)
             kwargs_node = tree_map(get_node, kwargs)
-            node = graph.create_node('call_function', func, args_node, kwargs_node)
+            node = graph.create_node("call_function", func, args_node, kwargs_node)
 
-            if 'device' in kwargs:
-                fake_device = kwargs['device']
-                kwargs['device'] = torch.device('meta')
+            if "device" in kwargs:
+                fake_device = kwargs["device"]
+                kwargs["device"] = torch.device("meta")
 
             args = tree_map(unwrap, args)
             kwargs = tree_map(unwrap, kwargs)
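Aside: tree_map here comes from torch.utils._pytree, a private utility this file relies on; it applies a function to every leaf of arbitrarily nested tuples, lists, and dicts, which is how args/kwargs are rewritten wholesale above. A small self-contained sketch of the same pattern (to_meta and the sample args are illustrative):

import torch
from torch.utils._pytree import tree_map

def to_meta(x):
    # Move tensors to the meta device; leave non-tensor leaves untouched.
    return x.to(torch.device("meta")) if isinstance(x, torch.Tensor) else x

args = (torch.ones(2), [torch.zeros(3), 4], {"w": torch.ones(1)})
meta_args = tree_map(to_meta, args)  # same nesting, tensors now on "meta"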
@@ -103,9 +103,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
                 if isinstance(x, torch.Tensor):
                     nonlocal fake_device
                     if not x.is_meta:
-                        x = x.to(torch.device('meta'))
-                return MetaProxy(
-                    x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
+                        x = x.to(torch.device("meta"))
+                return (
+                    MetaProxy(x, fake_device=fake_device)
+                    if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor")
+                    else x
+                )
 
             def set_node(x):
                 x._node = node
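Aside: the unwrap/wrap pair in this __torch_dispatch__ follows the standard tensor-subclass pattern: strip the subclass off the inputs, run the real aten op on plain tensors, then re-wrap the outputs (here as MetaProxy). A minimal sketch of the mechanism; PrintingTensor is a made-up class for illustration, not part of this codebase:

import torch
from torch.utils._pytree import tree_map

class PrintingTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            # Drop the subclass so the op executes on plain tensors.
            return x.as_subclass(torch.Tensor) if isinstance(x, PrintingTensor) else x
        print("dispatch:", func)  # every aten call lands here first
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

t = torch.ones(2, 2).as_subclass(PrintingTensor)
_ = t + t  # prints something like "dispatch: aten.add.Tensor"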
@@ -125,9 +128,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Graph:
     for tensor in normalize_tuple(out):
         if is_autogradable(tensor) and tensor.requires_grad:
-            grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
-                tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
-            torch.autograd.backward(tensor,
-                                    MetaProxy(grad, fake_device=tensor.device, placeholder=True),
-                                    retain_graph=True)
+            grad = (
+                torch.empty_like(tensor._tensor, device=torch.device("meta"))
+                if isinstance(tensor, MetaProxy)
+                else torch.empty_like(tensor, device=torch.device("meta"))
+            )
+            torch.autograd.backward(
+                tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True
+            )
 
     return graph
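Aside: end to end, meta_trace runs the module under the MetaProxy subclass, records every aten call (forward, and backward via the torch.autograd.backward call above) into a torch.fx Graph, and returns that graph. A hedged usage sketch; the import path below is a guess, so check where meta_trace lives in your checkout:

import torch
import torch.nn as nn
from colossalai.fx import meta_trace  # hypothetical import path

model = nn.Linear(8, 4)
# Meta-device inputs allocate no real memory, so arbitrarily large
# shapes can be traced cheaply.
data = torch.empty(2, 8, device="meta")
graph = meta_trace(model, torch.device("meta"), data)
print(graph)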