[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu (committed by GitHub) on 2023-09-19 14:20:26 +08:00
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -1,5 +1,5 @@
from dataclasses import asdict
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.fx
@@ -85,10 +85,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
self._is_proped = True
result, meta_info = super().run_node(n)
-n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
+n.meta = {**n.meta, **asdict(meta_info)}  # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
-setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
-n.meta['type'] = type(result)
+setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0))
+n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
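The reformatted setattr call above changes only the quote style; node_size is still the sum of the node's temporary and output forward memory. A minimal standalone sketch of that bookkeeping, with made-up byte counts:

# Hypothetical per-node meta entries standing in for what the profiler records.
meta = {"fwd_mem_tmp": 512, "fwd_mem_out": 2048}
# node_size sums temporary and output forward memory, defaulting to 0 when absent.
node_size = meta.get("fwd_mem_tmp", 0) + meta.get("fwd_mem_out", 0)
assert node_size == 2560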
@@ -98,7 +98,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
# Main Node running APIs
@compatibility(is_backward_compatible=True)
-def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@@ -119,7 +119,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
-def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@@ -138,7 +138,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
-def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
@@ -157,7 +157,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_function(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
-def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
@@ -175,7 +175,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_method(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
-def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
@@ -197,7 +197,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_module(submod, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
-def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@@ -228,7 +228,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
"""
return self.run(*args)
-def summary(self, unit: str = 'MB') -> str:
+def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
@@ -238,9 +238,11 @@ class ConcreteInfoProp(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
print(
"`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
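The assertion above fixes the intended call order: run() must populate the per-node meta before summary() is called. A sketch of that flow (the ConcreteInfoProp constructor and import path are not shown in this diff, so they are left as assumptions):

import torch
from torch.fx import symbolic_trace

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = symbolic_trace(Tiny())         # a small GraphModule to interpret
# Assumed constructor; only the interpreter methods appear in this diff.
# interp = ConcreteInfoProp(gm)
# interp.run(torch.randn(4, 4))     # fills n.meta with fwd/bwd time and memory
# print(interp.summary(unit="MB"))  # tabulated report; needs `tabulate` installed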
@@ -249,10 +251,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
def mem_repr(mem: int) -> str:
unit_divisor_map = {
-'kb': 1024,
-'mb': 1024**2,
-'gb': 1024**3,
-'tb': 1024**4,
+"kb": 1024,
+"mb": 1024**2,
+"gb": 1024**3,
+"tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
@@ -261,30 +263,32 @@ class ConcreteInfoProp(torch.fx.Interpreter):
for node in self.module.graph.nodes:
node: Node
-node_summaries.append([
-    node.op,
-    str(node),
-    time_repr(node.meta['fwd_time']),
-    time_repr(node.meta['bwd_time']),
-    node.meta['save_fwd_in'],
-    mem_repr(node.meta['fwd_mem_out']),
-    mem_repr(node.meta['fwd_mem_tmp']),
-    mem_repr(node.meta['bwd_mem_out']),
-    mem_repr(node.meta['bwd_mem_tmp']),
-])
+node_summaries.append(
+    [
+        node.op,
+        str(node),
+        time_repr(node.meta["fwd_time"]),
+        time_repr(node.meta["bwd_time"]),
+        node.meta["save_fwd_in"],
+        mem_repr(node.meta["fwd_mem_out"]),
+        mem_repr(node.meta["fwd_mem_tmp"]),
+        mem_repr(node.meta["bwd_mem_out"]),
+        mem_repr(node.meta["bwd_mem_tmp"]),
+    ]
+)
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
-'Op type',
-'Op',
-'Forward time',
-'Backward time',
-'SAVE_FWD_IN',
-'FWD_OUT',
-'FWD_TMP',
-'BWD_OUT',
-'BWD_TMP',
+"Op type",
+"Op",
+"Forward time",
+"Backward time",
+"SAVE_FWD_IN",
+"FWD_OUT",
+"FWD_TMP",
+"BWD_OUT",
+"BWD_TMP",
]
-return tabulate(node_summaries, headers=headers, stralign='right')
+return tabulate(node_summaries, headers=headers, stralign="right")
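The final change again only swaps quote style in the tabulate call. For reference, a minimal example with the same keyword arguments and made-up rows (requires `pip install tabulate`):

from tabulate import tabulate

headers = ["Op type", "Op", "Forward time", "FWD_OUT"]
rows = [
    ["call_module", "linear", "0.12 ms", "0.25 MB"],
    ["call_function", "relu", "0.03 ms", "0.25 MB"],
]
# stralign="right" right-aligns the string columns, as in summary() above.
print(tabulate(rows, headers=headers, stralign="right"))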