[fx/profiler] debug the fx.profiler / add an example test script for fx.profiler (#1730)

* [fx/profiler] add test.

* [fx] fix file names.

* [fx] add docstring and comment.

* [fx] polish profiler.py.

* [fx] fix import errors.

* [fx] fix profiler.

* [fx] fix names.
Super Daniel
2022-10-19 14:24:51 +08:00
committed by GitHub
parent eee84908d4
commit 30874f1692
6 changed files with 283 additions and 23 deletions

colossalai/fx/profiler/constants.py

@@ -1,6 +1,6 @@
 import torch
 
-__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
+__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'OUTPUT_SAVED_OPS', 'OUTPUT_SAVED_MOD']
 
 aten = torch.ops.aten
@@ -30,3 +30,15 @@ INPLACE_MATH_ATEN = [
 CLONE_ATEN = [
     aten.clone.default,
 ]
+
+# See illustrations in
+# https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/fx/profiler/constants.py
+OUTPUT_SAVED_OPS = [
+    torch.nn.functional.relu,
+    torch.nn.functional.softmax,
+]
+
+OUTPUT_SAVED_MOD = [
+    torch.nn.ReLU,
+    torch.nn.Softmax,
+]
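
A quick way to see why relu and softmax end up in `OUTPUT_SAVED_OPS` is to watch what autograd actually saves for backward. The sketch below is not part of this patch; it only uses plain PyTorch (`torch.autograd.graph.saved_tensors_hooks`) to show that softmax stashes its output rather than its input:

import torch
import torch.nn.functional as F

saved = []          # tensors autograd stashes for the backward pass

def pack(t):
    saved.append(t)
    return t

x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
    y = F.softmax(x, dim=-1)

# softmax saves its output (not its input) for backward,
# which is why it sits next to relu in OUTPUT_SAVED_OPS.
print(torch.allclose(saved[0], y))   # expected: True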

colossalai/fx/profiler/memory.py

@@ -5,6 +5,9 @@ from torch.fx import GraphModule, Node
 from .._compatibility import compatibility, is_compatible_with_meta
 
+if is_compatible_with_meta():
+    from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
+
 __all__ = [
     'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
 ]
@@ -71,14 +74,35 @@ def calculate_fwd_tmp(n: Node) -> int:
         fwd_tmp (int): the result of `fwd_tmp`
     """
 
-    def is_relu_node(n: Node) -> bool:
+    def is_relu_like_node(n: Node) -> bool:
+        """Check if a node is a ReLU-like node.
+        ReLU-like nodes have the following properties:
+        - They are either `call_function` or `call_module`
+        - Their output tensors are directly saved for backward
+        - Their input tensors are not saved for backward
+
+        An example is `torch.nn.functional.softmax` which has (forward + backward):
+            def forward(self, input_2):
+                _softmax_default = torch.ops.aten._softmax.default(input_2, None, None); input_2 = None
+                zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None)
+                detach_default = torch.ops.aten.detach.default(_softmax_default); _softmax_default = None
+                _softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None); zeros_like_default = detach_default = None
+                detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default); _softmax_backward_data_default = None
+                detach_default_2 = torch.ops.aten.detach.default(detach_default_1); detach_default_1 = None
+
+        Args:
+            n (Node): A node from the graph
+
+        Returns:
+            bool: Whether the node is a ReLU-like node
+        """
         if n.op == 'call_function':
-            return n.target in [torch.nn.functional.relu]
+            return n.target in OUTPUT_SAVED_OPS
         elif n.op == 'call_module':
-            return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
+            return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
         return False
 
-    if not is_relu_node(n):
+    if not is_relu_like_node(n):
         return activation_size(n.meta["fwd_tmp"])
     return 0
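
For reference, the check above can be exercised on its own with a plain torch.fx trace. This is a standalone sketch (constants inlined, toy module made up here), not profiler code:

import torch
import torch.fx
import torch.nn.functional as F

OUTPUT_SAVED_OPS = [F.relu, F.softmax]
OUTPUT_SAVED_MOD = [torch.nn.ReLU, torch.nn.Softmax]

def is_relu_like_node(n: torch.fx.Node) -> bool:
    # same logic as the helper above, usable outside the profiler
    if n.op == 'call_function':
        return n.target in OUTPUT_SAVED_OPS
    elif n.op == 'call_module':
        return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
    return False

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return F.softmax(self.act(self.proj(x)), dim=-1)

gm = torch.fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    print(f'{node.op:<15}{str(node.target):<45}{is_relu_like_node(node)}')
# `act` (call_module) and `softmax` (call_function) are ReLU-like;
# the placeholder, `proj` and output nodes are not.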

colossalai/fx/profiler/profiler.py

@@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter
 from torch.utils._pytree import tree_map
 
 from .._compatibility import compatibility
-from .constants import ALIAS_ATEN
+from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
 from .memory import activation_size, parameter_size
 from .opcount import flop_mapping
@@ -272,7 +272,8 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
             tensor = x._tensor.detach()
             tensor.uuid = x._tensor.uuid
             return tensor
-        return x
+        if not isinstance(x, torch.finfo):
+            return x
 
     graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))
@@ -314,21 +315,17 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
         # If there is an argument that this `call_function` is inplace, we should
         # still run the profiling but discard some results regarding `target`
         global do_not_cache
         inplace = kwargs.get('inplace', False)
-        if inplace or target in [torch.nn.functional.relu]:
+        if target in OUTPUT_SAVED_OPS:
             do_not_cache = True
+        if inplace:
+            do_not_cache = True
             kwargs['inplace'] = False
         if device == 'meta':
             out, meta = _profile_meta(func, *args, **kwargs)
-            # currently we set the fwd_mem_tmp of ReLU to zero
-            if target in [torch.nn.functional.relu]:
-                meta.fwd_in = []
-                meta.fwd_tmp = []
-                meta.bwd_mem_out = 0
-                meta.fwd_mem_tmp = 0
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             kwargs['inplace'] = True
         do_not_cache = False
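
The new control flow boils down to: never cache results for output-saving targets, and when the caller requested an inplace op, profile an out-of-place call and then hand the kwargs back untouched. Below is a standalone sketch of that pattern; `profile_call` is a made-up helper, not the real `profile_function`:

import torch
import torch.nn.functional as F

OUTPUT_SAVED_OPS = [F.relu, F.softmax]

def profile_call(target, *args, **kwargs):
    cacheable = target not in OUTPUT_SAVED_OPS      # mirrors the `do_not_cache` flag
    inplace = kwargs.get('inplace', False)
    if inplace:
        cacheable = False
        kwargs['inplace'] = False                   # profile an out-of-place version
    out = target(*args, **kwargs)
    if inplace:
        kwargs['inplace'] = True                    # restore the caller's kwargs
    return out, cacheable

x = torch.tensor([-1.0, 2.0, -3.0])
out, cacheable = profile_call(F.relu, x, inplace=True)
print(out, cacheable)       # tensor([0., 2., 0.]) False
print(x)                    # input left untouched: tensor([-1., 2., -3.])
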
@@ -386,20 +383,16 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
         global do_not_cache
         inplace = getattr(module, 'inplace', False)
-        if inplace or type(module) in [torch.nn.ReLU]:
+        if type(module) in OUTPUT_SAVED_MOD:
             do_not_cache = True
+        if inplace:
+            do_not_cache = True
             module.inplace = False
         if device == 'meta':
             out, meta = _profile_meta(func, *args, **kwargs)
-            # currently we set the fwd_tmp of ReLU to []
-            if type(module) in [torch.nn.ReLU]:
-                meta.fwd_in = []
-                meta.fwd_tmp = []
-                meta.bwd_mem_out = 0
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
             module.inplace = True
         do_not_cache = False
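
`profile_module` applies the same rule to `nn.Module` instances through their `inplace` attribute. The plain-PyTorch snippet below (not from the patch) shows why the flag must be cleared around the profiled call: an inplace `ReLU` overwrites its input.

import torch

relu = torch.nn.ReLU(inplace=True)

a = torch.tensor([-1.0, 0.5, -2.0])
b = a.clone()

relu(a)                      # inplace=True: `a` is overwritten with its ReLU
print(a)                     # tensor([0.0000, 0.5000, 0.0000])

relu.inplace = False         # what the profiler does around the profiled call
out = relu(b)
relu.inplace = True          # restore the user's module configuration

print(b)                     # input preserved: tensor([-1.0000, 0.5000, -2.0000])
print(out)                   # tensor([0.0000, 0.5000, 0.0000])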

colossalai/fx/profiler/tensor.py

@@ -125,5 +125,5 @@ class MetaTensor(torch.Tensor):
             device = kwargs['device']
         result = super().to(*args, **kwargs)
         if device is not None:
-            result = MetaTensor(deepcopy(result), fake_device=device)
+            result = MetaTensor(result, fake_device=device)
         return result
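
For context, the last hunk drops the `deepcopy` when re-wrapping the result of `.to()`. The usage sketch below is hedged: the import path and the reported device are assumptions about the ColossalAI API and may need adjusting for your version.

import torch
from colossalai.fx.profiler import MetaTensor   # assumed export location

# MetaTensor keeps its data on the 'meta' device while reporting a fake device,
# so .to() only needs to re-wrap the result rather than deep-copy it.
x = MetaTensor(torch.empty(8, 8, device='meta'), fake_device='cpu')
y = x.to(device='cuda')       # intended: no real CUDA allocation happens
print(y.device)               # expected to report the fake device, e.g. cuda:0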