[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,8 +1,8 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch
import torch.fx
from torch.autograd.profiler_util import _format_memory, _format_time
from torch.autograd.profiler_util import _format_memory
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
@@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo
def _format_flops(flops: float) -> str:
"""Returns a formatted FLOP size string"""
if flops > 1e12:
return f'{flops / 1e12:.2f} TFLOPs'
return f"{flops / 1e12:.2f} TFLOPs"
elif flops > 1e9:
return f'{flops / 1e9:.2f} GFLOPs'
return f"{flops / 1e9:.2f} GFLOPs"
elif flops > 1e6:
return f'{flops / 1e6:.2f} MFLOPs'
return f"{flops / 1e6:.2f} MFLOPs"
elif flops > 1e3:
return f'{flops / 1e3:.2f} kFLOPs'
return f'{flops} FLOPs'
return f"{flops / 1e3:.2f} kFLOPs"
return f"{flops} FLOPs"
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
Fetch shape argument from ``ShapeProp`` without re-executing
the ``GraphModule`` from scratch.
"""
_profileable = [
'call_function',
'call_module',
'call_method',
"call_function",
"call_module",
"call_method",
]
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
@@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter):
self.args_iter: Iterator[Any] = iter(args)
for node in self.module.graph.nodes:
self.run_node(node) # No need to store.
self.run_node(node) # No need to store.
if self.garbage_collect_values:
for to_delete in self.user_to_last_uses.get(node, []):
del self.env[to_delete]
if node.op == 'output':
if node.op == "output":
output_val = self.env[node]
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
@@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
print("`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library.")
print(
"`summary` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
@@ -145,36 +147,38 @@ class GraphProfiler(torch.fx.Interpreter):
node: Node
n_info = MetaInfo(node)
last_n_info = last_n_info or n_info
node_summaries.append([
node.op,
str(node),
_format_memory(n_info.accumulate_size),
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
_format_memory(n_info.output_size),
_format_memory(n_info.temp_size),
_format_memory(n_info.param_size),
_format_memory(n_info.backward_size),
_format_flops(n_info.fwd_flop),
_format_flops(n_info.bwd_flop),
])
node_summaries.append(
[
node.op,
str(node),
_format_memory(n_info.accumulate_size),
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
_format_memory(n_info.output_size),
_format_memory(n_info.temp_size),
_format_memory(n_info.param_size),
_format_memory(n_info.backward_size),
_format_flops(n_info.fwd_flop),
_format_flops(n_info.bwd_flop),
]
)
last_n_info = n_info
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
'Op type',
'Op',
'Accumulate size',
'Incremental size',
'Output size',
'Temp size',
'Param size',
'Backward size',
'Fwd FLOPs',
'Bwd FLOPs',
"Op type",
"Op",
"Accumulate size",
"Incremental size",
"Output size",
"Temp size",
"Param size",
"Backward size",
"Fwd FLOPs",
"Bwd FLOPs",
]
return tabulate(node_summaries, headers=headers, stralign='right')
return tabulate(node_summaries, headers=headers, stralign="right")
class CommunicationProfiler(GraphProfiler):
@@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler):
>>> def my_fn_flop_count_impl(*args, **kwargs):
>>> return 0, 0
"""
_custom_flop_count_impl = {}
def run_node(self, n: torch.fx.Node) -> Any:
@@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler):
(
n_info.fwd_flop,
n_info.bwd_flop,
) = getattr(self, n.op)(n.target, args, kwargs)
) = getattr(
self, n.op
)(n.target, args, kwargs)
except Exception as e:
raise RuntimeError(
f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
f"Please refer to function's docstring to register the relevant profile_impl for this node!"
) from e
# retain the autograd graph
@@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler):
return _denormalize_tuple(n_info.outputs)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the profiling result.
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
@@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler):
else:
return flop_count(target, *args, **kwargs)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the profiling result.
@@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler):
assert isinstance(target, str)
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node and return the profiling result.
@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
Returns:
GraphModule: The same GraphModule with profiling information
"""
for profiler_cls in (FlopProfiler,
# CommunicationProfiler, # TODO: add communication profiling
):
for profiler_cls in (
FlopProfiler,
# CommunicationProfiler, # TODO: add communication profiling
):
profiler = profiler_cls(module)
profiler.propagate(*args, device=_current_device(module))