mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -109,13 +109,13 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return TensorMetadata(None, None, False, None, 0, False)
|
||||
|
||||
tensor_meta = tree_map(extract_tensor_meta, result)
|
||||
n.meta['tensor_meta'] = tensor_meta
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
n.meta["tensor_meta"] = tensor_meta
|
||||
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
|
||||
# TODO: the attribute node_size should be removed in the future
|
||||
setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
|
||||
setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
|
||||
setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
|
||||
n.meta['type'] = type(result)
|
||||
setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0)))
|
||||
setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0))
|
||||
setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0))
|
||||
n.meta["type"] = type(result)
|
||||
|
||||
# retain the autograd graph
|
||||
for param in self.module.parameters():
|
||||
@@ -125,7 +125,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
# Main Node running APIs
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``placeholder`` node. Note that this is stateful:
|
||||
``Interpreter`` maintains an internal iterator over
|
||||
@@ -146,7 +146,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return super().placeholder(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``get_attr`` node. Will retrieve an attribute
|
||||
value from the ``Module`` hierarchy of ``self.module``.
|
||||
@@ -165,7 +165,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return super().get_attr(target, args, kwargs), GraphInfo()
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -184,7 +184,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return profile_function(target)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -202,7 +202,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return profile_method(target)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
|
||||
|
||||
@@ -224,7 +224,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
return profile_module(submod)(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute an ``output`` node. This really just retrieves
|
||||
the value referenced by the ``output`` node and returns it.
|
||||
@@ -240,7 +240,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
result (Any): The argument value that was retrieved
|
||||
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
|
||||
"""
|
||||
if hasattr(args[0], '_tensor'):
|
||||
if hasattr(args[0], "_tensor"):
|
||||
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
|
||||
return args[0], GraphInfo(save_fwd_in=True)
|
||||
|
||||
@@ -257,7 +257,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
"""
|
||||
return super().run(*args)
|
||||
|
||||
def summary(self, unit: str = 'MB') -> str:
|
||||
def summary(self, unit: str = "MB") -> str:
|
||||
"""
|
||||
Summarizes the memory and FLOPs statistics of the `GraphModule` in
|
||||
tabular format. Note that this API requires the ``tabulate`` module
|
||||
@@ -267,9 +267,11 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library.")
|
||||
print(
|
||||
"`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library."
|
||||
)
|
||||
|
||||
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
|
||||
|
||||
@@ -278,10 +280,10 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
|
||||
def mem_repr(mem: int) -> str:
|
||||
unit_divisor_map = {
|
||||
'kb': 1024,
|
||||
'mb': 1024**2,
|
||||
'gb': 1024**3,
|
||||
'tb': 1024**4,
|
||||
"kb": 1024,
|
||||
"mb": 1024**2,
|
||||
"gb": 1024**3,
|
||||
"tb": 1024**4,
|
||||
}
|
||||
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
|
||||
|
||||
@@ -292,35 +294,37 @@ class MetaInfoProp(torch.fx.Interpreter):
|
||||
for node in self.module.graph.nodes:
|
||||
node: Node
|
||||
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
|
||||
node_summaries.append([
|
||||
node.op,
|
||||
str(node),
|
||||
flops_repr(node.meta['fwd_flop']),
|
||||
flops_repr(node.meta['bwd_flop']),
|
||||
mem_repr(accumulate_size),
|
||||
mem_repr(calculate_fwd_in(node)),
|
||||
mem_repr(calculate_fwd_out(node)),
|
||||
mem_repr(calculate_fwd_tmp(node)),
|
||||
mem_repr(node.meta['bwd_mem_out']),
|
||||
mem_repr(node.meta['bwd_mem_tmp']),
|
||||
])
|
||||
node_summaries.append(
|
||||
[
|
||||
node.op,
|
||||
str(node),
|
||||
flops_repr(node.meta["fwd_flop"]),
|
||||
flops_repr(node.meta["bwd_flop"]),
|
||||
mem_repr(accumulate_size),
|
||||
mem_repr(calculate_fwd_in(node)),
|
||||
mem_repr(calculate_fwd_out(node)),
|
||||
mem_repr(calculate_fwd_tmp(node)),
|
||||
mem_repr(node.meta["bwd_mem_out"]),
|
||||
mem_repr(node.meta["bwd_mem_tmp"]),
|
||||
]
|
||||
)
|
||||
|
||||
# Use the ``tabulate`` library to create a well-formatted table
|
||||
# presenting our summary information
|
||||
headers: List[str] = [
|
||||
'Op type',
|
||||
'Op',
|
||||
'Forward FLOPs',
|
||||
'Backward FLOPs',
|
||||
'Accumulated Memory',
|
||||
'FWD_IN',
|
||||
'FWD_OUT',
|
||||
'FWD_TMP',
|
||||
'BWD_OUT',
|
||||
'BWD_TMP',
|
||||
"Op type",
|
||||
"Op",
|
||||
"Forward FLOPs",
|
||||
"Backward FLOPs",
|
||||
"Accumulated Memory",
|
||||
"FWD_IN",
|
||||
"FWD_OUT",
|
||||
"FWD_TMP",
|
||||
"BWD_OUT",
|
||||
"BWD_TMP",
|
||||
]
|
||||
|
||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||
return tabulate(node_summaries, headers=headers, stralign="right")
|
||||
|
||||
|
||||
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
|
||||
@@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
|
||||
Returns:
|
||||
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
|
||||
"""
|
||||
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
interp = MetaInfoProp(gm.to(device))
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
|
||||
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
|
||||
interp.propagate(*args, **kwargs)
|
||||
if verbose:
|
||||
interp.summary(unit)
|
||||
gm.to('cpu')
|
||||
gm.to("cpu")
|
||||
del interp
|
||||
return gm
|
||||
|
Reference in New Issue
Block a user