mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
@@ -85,10 +85,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
self._is_proped = True
|
||||
result, meta_info = super().run_node(n)
|
||||
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
# TODO: the attribute node_size should be removed in the future
|
||||
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
|
||||
n.meta['type'] = type(result)
|
||||
setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0))
|
||||
n.meta["type"] = type(result)
|
||||
|
||||
# retain the autograd graph
|
||||
for param in self.module.parameters():
|
||||
@@ -98,7 +98,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
|
||||
# Main Node running APIs
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``placeholder`` node. Note that this is stateful:
|
||||
``Interpreter`` maintains an internal iterator over
|
||||
@@ -119,7 +119,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return super().placeholder(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``get_attr`` node. Will retrieve an attribute
|
||||
value from the ``Module`` hierarchy of ``self.module``.
|
||||
@@ -138,7 +138,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return super().get_attr(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -157,7 +157,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return profile_function(target, self.device)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -175,7 +175,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return profile_method(target, self.device)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -197,7 +197,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
return profile_module(submod, self.device)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute an ``output`` node. This really just retrieves
|
||||
the value referenced by the ``output`` node and returns it.
|
||||
@@ -228,7 +228,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
"""
|
||||
return self.run(*args)
|
||||
|
||||
def summary(self, unit: str = 'MB') -> str:
|
||||
def summary(self, unit: str = "MB") -> str:
|
||||
"""
|
||||
Summarizes the memory and FLOPs statistics of the `GraphModule` in
|
||||
tabular format. Note that this API requires the ``tabulate`` module
|
||||
@@ -238,9 +238,11 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library.")
|
||||
print(
|
||||
"`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library."
|
||||
)
|
||||
|
||||
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
|
||||
|
||||
@@ -249,10 +251,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
|
||||
def mem_repr(mem: int) -> str:
|
||||
unit_divisor_map = {
|
||||
'kb': 1024,
|
||||
'mb': 1024**2,
|
||||
'gb': 1024**3,
|
||||
'tb': 1024**4,
|
||||
"kb": 1024,
|
||||
"mb": 1024**2,
|
||||
"gb": 1024**3,
|
||||
"tb": 1024**4,
|
||||
}
|
||||
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
|
||||
|
||||
@@ -261,30 +263,32 @@ class ConcreteInfoProp(torch.fx.Interpreter):
|
||||
|
||||
for node in self.module.graph.nodes:
|
||||
node: Node
|
||||
node_summaries.append([
|
||||
node.op,
|
||||
str(node),
|
||||
time_repr(node.meta['fwd_time']),
|
||||
time_repr(node.meta['bwd_time']),
|
||||
node.meta['save_fwd_in'],
|
||||
mem_repr(node.meta['fwd_mem_out']),
|
||||
mem_repr(node.meta['fwd_mem_tmp']),
|
||||
mem_repr(node.meta['bwd_mem_out']),
|
||||
mem_repr(node.meta['bwd_mem_tmp']),
|
||||
])
|
||||
node_summaries.append(
|
||||
[
|
||||
node.op,
|
||||
str(node),
|
||||
time_repr(node.meta["fwd_time"]),
|
||||
time_repr(node.meta["bwd_time"]),
|
||||
node.meta["save_fwd_in"],
|
||||
mem_repr(node.meta["fwd_mem_out"]),
|
||||
mem_repr(node.meta["fwd_mem_tmp"]),
|
||||
mem_repr(node.meta["bwd_mem_out"]),
|
||||
mem_repr(node.meta["bwd_mem_tmp"]),
|
||||
]
|
||||
)
|
||||
|
||||
# Use the ``tabulate`` library to create a well-formatted table
|
||||
# presenting our summary information
|
||||
headers: List[str] = [
|
||||
'Op type',
|
||||
'Op',
|
||||
'Forward time',
|
||||
'Backward time',
|
||||
'SAVE_FWD_IN',
|
||||
'FWD_OUT',
|
||||
'FWD_TMP',
|
||||
'BWD_OUT',
|
||||
'BWD_TMP',
|
||||
"Op type",
|
||||
"Op",
|
||||
"Forward time",
|
||||
"Backward time",
|
||||
"SAVE_FWD_IN",
|
||||
"FWD_OUT",
|
||||
"FWD_TMP",
|
||||
"BWD_OUT",
|
||||
"BWD_TMP",
|
||||
]
|
||||
|
||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||
return tabulate(node_summaries, headers=headers, stralign="right")
|
||||
|
Reference in New Issue
Block a user