mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.autograd.profiler_util import _format_memory, _format_time
|
||||
from torch.autograd.profiler_util import _format_memory
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.node import Argument, Node, Target
|
||||
|
||||
@@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo
|
||||
def _format_flops(flops: float) -> str:
|
||||
"""Returns a formatted FLOP size string"""
|
||||
if flops > 1e12:
|
||||
return f'{flops / 1e12:.2f} TFLOPs'
|
||||
return f"{flops / 1e12:.2f} TFLOPs"
|
||||
elif flops > 1e9:
|
||||
return f'{flops / 1e9:.2f} GFLOPs'
|
||||
return f"{flops / 1e9:.2f} GFLOPs"
|
||||
elif flops > 1e6:
|
||||
return f'{flops / 1e6:.2f} MFLOPs'
|
||||
return f"{flops / 1e6:.2f} MFLOPs"
|
||||
elif flops > 1e3:
|
||||
return f'{flops / 1e3:.2f} kFLOPs'
|
||||
return f'{flops} FLOPs'
|
||||
return f"{flops / 1e3:.2f} kFLOPs"
|
||||
return f"{flops} FLOPs"
|
||||
|
||||
|
||||
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
|
||||
Fetch shape argument from ``ShapeProp`` without re-executing
|
||||
the ``GraphModule`` from scratch.
|
||||
"""
|
||||
|
||||
_profileable = [
|
||||
'call_function',
|
||||
'call_module',
|
||||
'call_method',
|
||||
"call_function",
|
||||
"call_module",
|
||||
"call_method",
|
||||
]
|
||||
|
||||
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
|
||||
@@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter):
|
||||
self.args_iter: Iterator[Any] = iter(args)
|
||||
|
||||
for node in self.module.graph.nodes:
|
||||
|
||||
self.run_node(node) # No need to store.
|
||||
self.run_node(node) # No need to store.
|
||||
|
||||
if self.garbage_collect_values:
|
||||
for to_delete in self.user_to_last_uses.get(node, []):
|
||||
del self.env[to_delete]
|
||||
|
||||
if node.op == 'output':
|
||||
if node.op == "output":
|
||||
output_val = self.env[node]
|
||||
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
|
||||
|
||||
@@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter):
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library.")
|
||||
print(
|
||||
"`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library."
|
||||
)
|
||||
|
||||
# Build up a list of summary information for each node
|
||||
node_summaries: List[List[Any]] = []
|
||||
@@ -145,36 +147,38 @@ class GraphProfiler(torch.fx.Interpreter):
|
||||
node: Node
|
||||
n_info = MetaInfo(node)
|
||||
last_n_info = last_n_info or n_info
|
||||
node_summaries.append([
|
||||
node.op,
|
||||
str(node),
|
||||
_format_memory(n_info.accumulate_size),
|
||||
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
|
||||
_format_memory(n_info.output_size),
|
||||
_format_memory(n_info.temp_size),
|
||||
_format_memory(n_info.param_size),
|
||||
_format_memory(n_info.backward_size),
|
||||
_format_flops(n_info.fwd_flop),
|
||||
_format_flops(n_info.bwd_flop),
|
||||
])
|
||||
node_summaries.append(
|
||||
[
|
||||
node.op,
|
||||
str(node),
|
||||
_format_memory(n_info.accumulate_size),
|
||||
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
|
||||
_format_memory(n_info.output_size),
|
||||
_format_memory(n_info.temp_size),
|
||||
_format_memory(n_info.param_size),
|
||||
_format_memory(n_info.backward_size),
|
||||
_format_flops(n_info.fwd_flop),
|
||||
_format_flops(n_info.bwd_flop),
|
||||
]
|
||||
)
|
||||
last_n_info = n_info
|
||||
|
||||
# Use the ``tabulate`` library to create a well-formatted table
|
||||
# presenting our summary information
|
||||
headers: List[str] = [
|
||||
'Op type',
|
||||
'Op',
|
||||
'Accumulate size',
|
||||
'Incremental size',
|
||||
'Output size',
|
||||
'Temp size',
|
||||
'Param size',
|
||||
'Backward size',
|
||||
'Fwd FLOPs',
|
||||
'Bwd FLOPs',
|
||||
"Op type",
|
||||
"Op",
|
||||
"Accumulate size",
|
||||
"Incremental size",
|
||||
"Output size",
|
||||
"Temp size",
|
||||
"Param size",
|
||||
"Backward size",
|
||||
"Fwd FLOPs",
|
||||
"Bwd FLOPs",
|
||||
]
|
||||
|
||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||
return tabulate(node_summaries, headers=headers, stralign="right")
|
||||
|
||||
|
||||
class CommunicationProfiler(GraphProfiler):
|
||||
@@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler):
|
||||
>>> def my_fn_flop_count_impl(*args, **kwargs):
|
||||
>>> return 0, 0
|
||||
"""
|
||||
|
||||
_custom_flop_count_impl = {}
|
||||
|
||||
def run_node(self, n: torch.fx.Node) -> Any:
|
||||
@@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler):
|
||||
(
|
||||
n_info.fwd_flop,
|
||||
n_info.bwd_flop,
|
||||
) = getattr(self, n.op)(n.target, args, kwargs)
|
||||
) = getattr(
|
||||
self, n.op
|
||||
)(n.target, args, kwargs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
|
||||
f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
|
||||
f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
|
||||
f"Please refer to function's docstring to register the relevant profile_impl for this node!"
|
||||
) from e
|
||||
|
||||
# retain the autograd graph
|
||||
@@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler):
|
||||
|
||||
return _denormalize_tuple(n_info.outputs)
|
||||
|
||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node and return the profiling result.
|
||||
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
|
||||
@@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler):
|
||||
else:
|
||||
return flop_count(target, *args, **kwargs)
|
||||
|
||||
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_method`` node and return the profiling result.
|
||||
|
||||
@@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler):
|
||||
assert isinstance(target, str)
|
||||
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
|
||||
|
||||
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_module`` node and return the profiling result.
|
||||
|
||||
@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
|
||||
Returns:
|
||||
GraphModule: The same GraphModule with profiling information
|
||||
"""
|
||||
for profiler_cls in (FlopProfiler,
|
||||
# CommunicationProfiler, # TODO: add communication profiling
|
||||
):
|
||||
for profiler_cls in (
|
||||
FlopProfiler,
|
||||
# CommunicationProfiler, # TODO: add communication profiling
|
||||
):
|
||||
profiler = profiler_cls(module)
|
||||
profiler.propagate(*args, device=_current_device(module))
|
||||
|
||||
|
Reference in New Issue
Block a user