mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[analyzer] a minimal implementation of static graph analyzer (#2852)
* [hotfix] meta tensor default device. * [siu] add experimental submodules to main branch. * [siu] * [siu] * [analyzer] init. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [test] add test. * Update symbolic_trace.py * mark skip tests. * try except. * try except. * try except. * s * init * init * fix * skip * skip --------- Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com> Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
This commit is contained in:
347
colossalai/_analyzer/fx/passes/graph_profile.py
Normal file
347
colossalai/_analyzer/fx/passes/graph_profile.py
Normal file
@@ -0,0 +1,347 @@
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.autograd.profiler_util import _format_memory, _format_time
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.node import Argument, Node, Target
|
||||
|
||||
from colossalai._analyzer._subclasses import flop_count
|
||||
from colossalai._analyzer.fx.node_util import MetaInfo
|
||||
|
||||
|
||||
def _format_flops(flops: float) -> str:
|
||||
"""Returns a formatted FLOP size string"""
|
||||
if flops > 1e12:
|
||||
return f'{flops / 1e12:.2f} TFLOPs'
|
||||
elif flops > 1e9:
|
||||
return f'{flops / 1e9:.2f} GFLOPs'
|
||||
elif flops > 1e6:
|
||||
return f'{flops / 1e6:.2f} MFLOPs'
|
||||
elif flops > 1e3:
|
||||
return f'{flops / 1e3:.2f} kFLOPs'
|
||||
return f'{flops} FLOPs'
|
||||
|
||||
|
||||
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
return t[0] if len(t) == 1 else t
|
||||
|
||||
|
||||
def _normalize_tuple(x):
|
||||
if not isinstance(x, tuple):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
|
||||
def _current_device(module):
|
||||
return next(module.parameters()).device
|
||||
|
||||
|
||||
class GraphProfiler(torch.fx.Interpreter):
|
||||
"""
|
||||
Fetch shape argument from ``ShapeProp`` without re-executing
|
||||
the ``GraphModule`` from scratch.
|
||||
"""
|
||||
_profileable = [
|
||||
'call_function',
|
||||
'call_module',
|
||||
'call_method',
|
||||
]
|
||||
|
||||
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
|
||||
super().__init__(module, garbage_collect_values)
|
||||
|
||||
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
|
||||
"""
|
||||
Run `module` via interpretation and return the result.
|
||||
|
||||
Args:
|
||||
*args: The arguments to the Module to run, in positional order
|
||||
initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
|
||||
This is a dict mapping `Node` to any value. This can be used, for example, to
|
||||
pre-populate results for certain `Nodes` so as to do only partial evaluation within
|
||||
the interpreter.
|
||||
enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
|
||||
process_outputs function first before using them.
|
||||
|
||||
Returns:
|
||||
Any: The value returned from executing the Module
|
||||
"""
|
||||
self.env = initial_env if initial_env else {}
|
||||
|
||||
# Positional function args are consumed left-to-right by
|
||||
# `placeholder` nodes. Use an iterator to keep track of
|
||||
# position and extract those values.
|
||||
if enable_io_processing:
|
||||
args = self.module.graph.process_inputs(*args)
|
||||
self.args_iter: Iterator[Any] = iter(args)
|
||||
|
||||
for node in self.module.graph.nodes:
|
||||
|
||||
self.run_node(node) # No need to store.
|
||||
|
||||
if self.garbage_collect_values:
|
||||
for to_delete in self.user_to_last_uses.get(node, []):
|
||||
del self.env[to_delete]
|
||||
|
||||
if node.op == 'output':
|
||||
output_val = self.env[node]
|
||||
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
|
||||
|
||||
def fetch_initial_env(self, device=None) -> Dict[Node, Any]:
|
||||
"""
|
||||
Fetch ``initial_env`` for execution. This is because ``ShapeProp``
|
||||
has already attached outputs of each ``Node`` to its ``MetaInfo``.
|
||||
|
||||
Args:
|
||||
device (torch.device): The device to place the execution, default to ``None``
|
||||
|
||||
Returns:
|
||||
Dict[Node, Any]: The initial environment for execution
|
||||
"""
|
||||
initial_env = {}
|
||||
for n in self.module.graph.nodes:
|
||||
initial_env[n] = _denormalize_tuple(MetaInfo(n).outputs)
|
||||
return initial_env
|
||||
|
||||
def propagate(self, *args, device=None):
|
||||
"""
|
||||
Run `module` via interpretation and profile the execution
|
||||
of each ``Node``.
|
||||
|
||||
Args:
|
||||
*args (Tensor): The sample input, not used
|
||||
device (torch.device): The device to place the execution, default to ``None``
|
||||
|
||||
Returns:
|
||||
Any: The value returned from executing the Module
|
||||
"""
|
||||
initial_env = self.fetch_initial_env(device)
|
||||
|
||||
return self.run(initial_env=initial_env)
|
||||
|
||||
def summary(self) -> str:
|
||||
"""
|
||||
Summarizes the profiled statistics of the `GraphModule` in
|
||||
tabular format. Note that this API requires the ``tabulate`` module
|
||||
to be installed.
|
||||
|
||||
Returns:
|
||||
str: The summary of the profiled statistics
|
||||
"""
|
||||
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError:
|
||||
print("`summary` relies on the library `tabulate`, "
|
||||
"which could not be found on this machine. Run `pip "
|
||||
"install tabulate` to install the library.")
|
||||
|
||||
# Build up a list of summary information for each node
|
||||
node_summaries: List[List[Any]] = []
|
||||
last_n_info = None
|
||||
|
||||
for node in self.module.graph.nodes:
|
||||
node: Node
|
||||
n_info = MetaInfo(node)
|
||||
last_n_info = last_n_info or n_info
|
||||
node_summaries.append([
|
||||
node.op,
|
||||
str(node),
|
||||
_format_memory(n_info.accumulate_size),
|
||||
_format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
|
||||
_format_memory(n_info.output_size),
|
||||
_format_memory(n_info.temp_size),
|
||||
_format_memory(n_info.param_size),
|
||||
_format_memory(n_info.backward_size),
|
||||
_format_flops(n_info.fwd_flop),
|
||||
_format_flops(n_info.bwd_flop),
|
||||
])
|
||||
last_n_info = n_info
|
||||
|
||||
# Use the ``tabulate`` library to create a well-formatted table
|
||||
# presenting our summary information
|
||||
headers: List[str] = [
|
||||
'Op type',
|
||||
'Op',
|
||||
'Accumulate size',
|
||||
'Incremental size',
|
||||
'Output size',
|
||||
'Temp size',
|
||||
'Param size',
|
||||
'Backward size',
|
||||
'Fwd FLOPs',
|
||||
'Bwd FLOPs',
|
||||
]
|
||||
|
||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||
|
||||
|
||||
class CommunicationProfiler(GraphProfiler):
|
||||
"""
|
||||
TODO(lyl): Add this for all comm nodes
|
||||
"""
|
||||
|
||||
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class FlopProfiler(GraphProfiler):
|
||||
"""
|
||||
Execute an FX graph Node-by-Node and record the meta data of the result
|
||||
into the corresponding node.
|
||||
|
||||
Usage:
|
||||
>>> model = MyModule()
|
||||
>>> x = torch.rand(10, 10)
|
||||
>>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x}})
|
||||
>>> shape_interp = ShapeProp(gm) # must do this first
|
||||
>>> shape_interp.propagate(x)
|
||||
>>> profiler = FlopProfiler(gm)
|
||||
>>> profiler.propagate(x)
|
||||
|
||||
Args:
|
||||
module (GraphModule): The module to be executed
|
||||
|
||||
Hints:
|
||||
If you want to add a new flop count rule, you can first
|
||||
check the existing files in ``../_subclasses/flop_tensor.py``.
|
||||
If your flop count rules are incompatible with the existing
|
||||
ones, you can do so by adding a new method to this class
|
||||
with the ``@register_flop_count_impl`` decorator. The method
|
||||
should take (*args, **kwargs) instance as its input and
|
||||
generate flop count for both forward and backward as its
|
||||
output.
|
||||
|
||||
For example, if you want to add a flop count rule for
|
||||
``my_fn``, which is a hand-written operand not detected by
|
||||
PyTorch, you can do so by adding a new method to this
|
||||
class with the ``@register_flop_count_impl`` decorator:
|
||||
|
||||
>>> @register_flop_count_impl(my_fn)
|
||||
>>> def my_fn_flop_count_impl(*args, **kwargs):
|
||||
>>> return 0, 0
|
||||
"""
|
||||
_custom_flop_count_impl = {}
|
||||
|
||||
def run_node(self, n: torch.fx.Node) -> Any:
|
||||
"""
|
||||
Run a specific node ``n`` and profile its execution time and memory usage.
|
||||
Calls into call_function, call_method, and call_module only.
|
||||
|
||||
Args:
|
||||
n (Node): The Node to profile
|
||||
|
||||
Returns:
|
||||
Any: The output of the node
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the node is not profileable.
|
||||
"""
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
||||
n_info = MetaInfo(n)
|
||||
|
||||
if n.op in self._profileable:
|
||||
try:
|
||||
(
|
||||
n_info.fwd_flop,
|
||||
n_info.bwd_flop,
|
||||
) = getattr(self, n.op)(n.target, args, kwargs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
|
||||
f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
|
||||
) from e
|
||||
|
||||
# retain the autograd graph
|
||||
for param in self.module.parameters():
|
||||
param.grad = None
|
||||
|
||||
return _denormalize_tuple(n_info.outputs)
|
||||
|
||||
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_function`` node and return the profiling result.
|
||||
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
|
||||
profiled in a user-defined behavior.
|
||||
|
||||
Args:
|
||||
target (Target): The call target for this node. See
|
||||
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||
details on semantics
|
||||
args (Tuple): Tuple of positional args for this invocation
|
||||
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||
|
||||
Return
|
||||
flop_count (Tuple[int]): (fwd_flop, bwd_flop)
|
||||
"""
|
||||
assert not isinstance(target, str)
|
||||
|
||||
# Dispatch the impl for profiling, default will be ``flop_count``
|
||||
if target in self._custom_flop_count_impl:
|
||||
return self._custom_flop_count_impl[target](*args, **kwargs)
|
||||
else:
|
||||
return flop_count(target, *args, **kwargs)
|
||||
|
||||
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_method`` node and return the profiling result.
|
||||
|
||||
Args:
|
||||
target (Target): The call target for this node. See
|
||||
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||
details on semantics
|
||||
args (Tuple): Tuple of positional args for this invocation
|
||||
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||
|
||||
Return
|
||||
flop_count (Tuple[int]): (fwd_flop, bwd_flop)
|
||||
"""
|
||||
# Execute the method and return the result
|
||||
assert isinstance(target, str)
|
||||
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
|
||||
|
||||
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Execute a ``call_module`` node and return the profiling result.
|
||||
|
||||
Args:
|
||||
target (Target): The call target for this node. See
|
||||
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||
details on semantics
|
||||
args (Tuple): Tuple of positional args for this invocation
|
||||
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||
|
||||
Return
|
||||
flop_count (Tuple[int]): (fwd_flop, bwd_flop)
|
||||
"""
|
||||
# Retrieve executed args and kwargs values from the environment
|
||||
|
||||
# Execute the method and return the result
|
||||
assert isinstance(target, str)
|
||||
submod = self.fetch_attr(target)
|
||||
return flop_count(submod, *args, **kwargs)
|
||||
|
||||
|
||||
def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule:
|
||||
"""
|
||||
Run ``module`` via interpretation and profile the execution
|
||||
of each ``Node``.
|
||||
|
||||
Args:
|
||||
module (GraphModule): The GraphModule to profile
|
||||
*args (Any): The sample input, not used
|
||||
verbose (bool): Whether to print the profiling summary
|
||||
|
||||
Returns:
|
||||
GraphModule: The same GraphModule with profiling information
|
||||
"""
|
||||
for profiler_cls in (FlopProfiler,
|
||||
# CommunicationProfiler, # TODO: add communication profiling
|
||||
):
|
||||
profiler = profiler_cls(module)
|
||||
profiler.propagate(*args, device=_current_device(module))
|
||||
|
||||
if verbose:
|
||||
print(profiler.summary())
|
||||
return module
|
Reference in New Issue
Block a user