mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-18 16:00:49 +00:00
[fx/profiler] debug the fx.profiler / add an example test script for fx.profiler (#1730)
* [fx/profiler] add test. * [fx] fix file names. * [fx] add docstring and comment. * [fx] polish profiler.py. * [fx] fix import errors. * [fx] fix profiler. * [fx] fix names.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN']
|
||||
__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD']
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
@@ -30,3 +30,15 @@ INPLACE_MATH_ATEN = [
|
||||
CLONE_ATEN = [
|
||||
aten.clone.default,
|
||||
]
|
||||
|
||||
# See illustrations in
|
||||
# https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/fx/profiler/constants.py
|
||||
OUTPUT_SAVED_OPS = [
|
||||
torch.nn.functional.relu,
|
||||
torch.nn.functional.softmax,
|
||||
]
|
||||
|
||||
OUTPUT_SAVED_MOD = [
|
||||
torch.nn.ReLU,
|
||||
torch.nn.Softmax,
|
||||
]
|
||||
|
@@ -5,6 +5,9 @@ from torch.fx import GraphModule, Node
|
||||
|
||||
from .._compatibility import compatibility, is_compatible_with_meta
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
||||
|
||||
__all__ = [
|
||||
'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
|
||||
]
|
||||
@@ -71,14 +74,35 @@ def calculate_fwd_tmp(n: Node) -> int:
|
||||
fwd_tmp (int): the result of `fwd_tmp`
|
||||
"""
|
||||
|
||||
def is_relu_node(n: Node) -> bool:
|
||||
def is_relu_like_node(n: Node) -> bool:
|
||||
"""Check if a node is a ReLU-like node.
|
||||
ReLU-like nodes have the following properties:
|
||||
- They are either `call_function` or `call_module`
|
||||
- Their output tensors are directly saved for backward
|
||||
- Their input tensors are not saved for backward
|
||||
|
||||
An example is `torch.nn.functional.softmax` which has (forward + backward):
|
||||
def forward(self, input_2):
|
||||
_softmax_default = torch.ops.aten._softmax.default(input_2, None, None); input_2 = None
|
||||
zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None)
|
||||
detach_default = torch.ops.aten.detach.default(_softmax_default); _softmax_default = None
|
||||
_softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None); zeros_like_default = detach_default = None
|
||||
detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default); _softmax_backward_data_default = None
|
||||
detach_default_2 = torch.ops.aten.detach.default(detach_default_1); detach_default_1 = None
|
||||
|
||||
Args:
|
||||
n (Node): A node from the graph
|
||||
|
||||
Returns:
|
||||
bool: Whether the node is a ReLU-like node
|
||||
"""
|
||||
if n.op == 'call_function':
|
||||
return n.target in [torch.nn.functional.relu]
|
||||
return n.target in OUTPUT_SAVED_OPS
|
||||
elif n.op == 'call_module':
|
||||
return type(n.graph.owning_module.get_submodule(n.target)) in [torch.nn.ReLU]
|
||||
return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
|
||||
return False
|
||||
|
||||
if not is_relu_node(n):
|
||||
if not is_relu_like_node(n):
|
||||
return activation_size(n.meta["fwd_tmp"])
|
||||
return 0
|
||||
|
||||
|
@@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from .._compatibility import compatibility
|
||||
from .constants import ALIAS_ATEN
|
||||
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
||||
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
|
||||
from .memory import activation_size, parameter_size
|
||||
from .opcount import flop_mapping
|
||||
@@ -272,7 +272,8 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
|
||||
tensor = x._tensor.detach()
|
||||
tensor.uuid = x._tensor.uuid
|
||||
return tensor
|
||||
return x
|
||||
if not isinstance(x, torch.finfo):
|
||||
return x
|
||||
|
||||
graph_info.fwd_out = list(map(extract_tensor, normalize_tuple(out)))
|
||||
|
||||
@@ -314,21 +315,17 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
|
||||
# If there is an argument that this `call_function` is inplace, we should
|
||||
# still run the profiling but discard some results regarding `target`
|
||||
global do_not_cache
|
||||
|
||||
inplace = kwargs.get('inplace', False)
|
||||
if inplace or target in [torch.nn.functional.relu]:
|
||||
if target in OUTPUT_SAVED_OPS:
|
||||
do_not_cache = True
|
||||
if inplace:
|
||||
do_not_cache = True
|
||||
kwargs['inplace'] = False
|
||||
if device == 'meta':
|
||||
out, meta = _profile_meta(func, *args, **kwargs)
|
||||
# currently we set the fwd_mem_tmp of ReLU to zero
|
||||
if target in [torch.nn.functional.relu]:
|
||||
meta.fwd_in = []
|
||||
meta.fwd_tmp = []
|
||||
meta.bwd_mem_out = 0
|
||||
meta.fwd_mem_tmp = 0
|
||||
else:
|
||||
out, meta = _profile_concrete(func, *args, **kwargs)
|
||||
|
||||
if inplace:
|
||||
kwargs['inplace'] = True
|
||||
do_not_cache = False
|
||||
@@ -386,20 +383,16 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
|
||||
global do_not_cache
|
||||
|
||||
inplace = getattr(module, 'inplace', False)
|
||||
if inplace or type(module) in [torch.nn.ReLU]:
|
||||
if type(module) in OUTPUT_SAVED_MOD:
|
||||
do_not_cache = True
|
||||
if inplace:
|
||||
do_not_cache = True
|
||||
module.inplace = False
|
||||
if device == 'meta':
|
||||
out, meta = _profile_meta(func, *args, **kwargs)
|
||||
# currently we set the fwd_tmp of ReLU to []
|
||||
if type(module) in [torch.nn.ReLU]:
|
||||
meta.fwd_in = []
|
||||
meta.fwd_tmp = []
|
||||
meta.bwd_mem_out = 0
|
||||
else:
|
||||
out, meta = _profile_concrete(func, *args, **kwargs)
|
||||
if inplace:
|
||||
|
||||
module.inplace = True
|
||||
do_not_cache = False
|
||||
|
||||
|
@@ -125,5 +125,5 @@ class MetaTensor(torch.Tensor):
|
||||
device = kwargs['device']
|
||||
result = super().to(*args, **kwargs)
|
||||
if device is not None:
|
||||
result = MetaTensor(deepcopy(result), fake_device=device)
|
||||
result = MetaTensor(result, fake_device=device)
|
||||
return result
|
||||
|
Reference in New Issue
Block a user