mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-15 14:12:02 +00:00)
[fx] provide a stable but not accurate enough version of profiler. (#1547)
* [fx] compute memory stat and flop count for MetaInfoProp.
* [fx] modify node attribute.
* [fx] modify ckpt_chen.
* [fx] fix compatibility.
* [fx] fix import error.
* [fx] skip test for MetaInfoProp.
* [fx] skip test for MetaInfoProp.
* [fx] skip test for MetaInfoProp.
* [fx] skip test for MetaInfoProp.
* [fx] skip if torch 1.11.0.
* [fx] recover MetaInfoProp support for PyTorch 1.11.
* [fx] provide a stable but not accurate enough version of profiler.
* [fx] provide a stable but not accurate enough version of profiler.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix compatibility in tests.
* [fx] fix import error.
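The user-visible change: instead of returning a MetaProfile record (param, activation, flops, macs), the profiler wrappers now return the output together with a flop tuple and a memory tuple. A minimal usage sketch pieced together from the docstring examples in this diff; the `colossalai.fx.profiler` import path is assumed from the error-message templates below, not confirmed by this commit:

import torch
from colossalai.fx.profiler import profile_function, profile_module    # assumed re-export

# Profile a functional op on meta tensors; nothing is actually computed.
inp = torch.rand(100, 100, 100, 100, device='meta')
out, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = \
    profile_function(torch.nn.functional.relu)(inp, inplace=False)

# Modules are wrapped the same way.
mod = torch.nn.Conv2d(3, 128, 3)
inp = torch.rand(4, 3, 224, 224, device='meta')
out, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(inp)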
@@ -1,120 +1,121 @@
-from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
-from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union
+from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx import Graph
 from torch.fx.node import Argument, Target
 from torch.fx._compatibility import compatibility
-from . import meta_profiler_function, meta_profiler_module
+from torch.utils._pytree import tree_map
+from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS
+from .tensor import MetaTensor
+from .opcount import flop_mapping

-__all__ = [
-    'MetaProfile', 'profile_function', 'profile_module', 'profile_method', 'calculate_activation_size',
-    'calculate_param_size'
-]
-CALL_FUNCTION_MSG = \
-"""
-Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
-from colossalai.fx.profiler import meta_profiler_function
-
-
-@meta_profiler_function.register(YOUR_FUNCTION)
-def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
-    flops = ...
-    macs = ...
-    return flops, macs
-"""
-CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
-CALL_MODULE_MSG = \
-"""
-Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
-from colossalai.fx.profiler import meta_profiler_module
-
-
-@meta_profiler_module.register(YOUR_MODULE)
-def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
-    flops = ...
-    macs = ...
-    return flops, macs
-"""
-# TODO fill out the inplace ops
-INPLACE_OPS = [
-    add,
-    sub,
-    mul,
-    floordiv,
-    neg,
-    pos,
-    getitem,
-    setitem,
-    getattr,
-    torch.Tensor.cpu,
-]
-
-# TODO: list all call_methods that are inplace here
-INPLACE_METHOD = [
-    'transpose',
-    'permute',
-    # TODO: reshape may return a copy of the data if the data is not contiguous
-    'reshape',
-    'dim',
-    'flatten',
-    'size',
-    'view',
-    'unsqueeze',
-    'to',
-]
-
-# TODO: list all call_methods that are not inplace here
-NON_INPLACE_METHOD = [
-    'expand',
-    'mean',
-]
+__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile']
-@compatibility(is_backward_compatible=True)
-class MetaProfile(NamedTuple):
-
-    # MetaProfile is a structure containing pertinent information
-    # about a node within a torch.fx GraphModule.
-
-    param: int
-    activation: int
-    flops: int
-    macs: int
+def normalize_tuple(x):
+    if not isinstance(x, tuple):
+        return (x,)
+    return x
-def calculate_activation_size(activation: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
-    """Calculate activation size of a node.
-
-    Args:
-        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
-
-    Returns:
-        int: The activation size
-    """
-    activation_size = 0
-    if isinstance(activation, torch.Tensor):
-        activation_size += activation.numel() * torch.tensor([], dtype=activation.dtype).element_size()
-    elif isinstance(activation, dict):
-        value_list = [v for _, v in activation.items()]
-        activation_size += calculate_activation_size(value_list)
-    elif isinstance(activation, tuple) or isinstance(activation, list):
-        for element in activation:
-            activation_size += calculate_activation_size(element)
-    return activation_size
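Note the size arithmetic here: a float32 element occupies 4 bytes, so an activation's size is just numel times element_size. A quick hypothetical check that reproduces the 381.470 MB figure quoted in the old relu docstring example further down:

import torch

x = torch.rand(100, 100, 100, 100, device='meta')    # 10**8 float32 elements
nbytes = x.numel() * torch.tensor([], dtype=x.dtype).element_size()
print(nbytes / 1024**2)    # ~381.470 MB, matching the docstring example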
+def is_autogradable(x):
+    return isinstance(x, torch.Tensor) and x.is_floating_point()


+def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
+    """Profile a Callable function with args and kwargs.
+
+    Args:
+        target (Callable): A Callable function
+        args (Any): Argument
+        kwargs (Any): Argument
+
+    Returns:
+        out (Tuple[Any, ...]): The argument value that was retrieved
+        flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop).
+        mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
+    """
+    flop_count = {
+        'f': 0,
+        'l': 0,
+        'b': 0,
+    }
+    temp = {
+        'f': [],
+        'l': [],
+        'b': [],
+    }
+    stage = 'f'
-def calculate_param_size(mod: torch.nn.Module) -> int:
-    """Calculate param size of a node.
-
-    Args:
-        mod (torch.nn.Module): The target `torch.nn.Module`
-
-    Returns:
-        int: The param size
-    """
-    param_size = 0
-    for param in mod.parameters():
-        param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
-    return param_size
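The same byte counting applies to parameters. A hypothetical check against the Conv2d example quoted further down (weight 128*3*3*3 = 3456 plus bias 128 gives 3584 float32 parameters):

import torch

mod = torch.nn.Conv2d(3, 128, 3)    # 3584 params in total
param_bytes = sum(p.numel() * p.element_size() for p in mod.parameters())
print(param_bytes / 1024**2)    # ~0.014 MB, matching the docstring example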
+    class FlopTensor(MetaTensor):

+        def __repr__(self):
+            if self.grad_fn:
+                return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
+            return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"

+        @classmethod
+        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

+            def unwrap(x):
+                if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
+                    x = FlopTensor(x.to('meta'))
+                return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

+            def to_meta(x):
+                return x.to('meta') if isinstance(x, torch.Tensor) else x

+            args = tree_map(unwrap, args)
+            kwargs = tree_map(unwrap, kwargs)

+            # run aten for backend=CPU but actually on backend=Meta
+            out = func(*args, **kwargs)
+            flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
+            if func not in INPLACE_ATEN:
+                temp[stage].append(tree_map(to_meta, normalize_tuple(out)))

+            def wrap(x):
+                return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x

+            return tree_map(wrap, out)

+    if target not in WEIRD_OPS:

+        def wrap(x):
+            return FlopTensor(
+                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
+    else:

+        def wrap(x):
+            return FlopTensor(
+                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x

+    args = tree_map(wrap, args)
+    kwargs = tree_map(wrap, kwargs)

+    if isinstance(target, str):
+        # args[0] is the `self` object for this method call
+        self_obj, *args_tail = args
+        out = getattr(self_obj, target)(*args_tail, **kwargs)
+    else:
+        out = target(*args, **kwargs)

+    if is_autogradable(out) and out.requires_grad:
+        stage = 'l'
+        loss = out.sum()
+        stage = 'b'
+        loss.backward()

+    fwd_flop = flop_count['f']
+    bwd_flop = flop_count['b']

+    fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0
+    fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0
+    bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0

+    def unwrap(x):
+        return x._tensor.to('meta') if isinstance(x, FlopTensor) else x

+    return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0)
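The mechanism above: _profile wraps every autogradable input in a FlopTensor, so each aten op the target executes is routed through __torch_dispatch__, where its cost is looked up in flop_mapping and its outputs are recorded, all on meta tensors so no real computation happens. The same idea in a self-contained sketch using TorchDispatchMode, a related interception hook in recent PyTorch; FLOP_MAPPING here is a stand-in for the real flop_mapping table, and the matmul formula is an illustrative assumption:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Stand-in cost table: aten overload -> flop formula (2 flops per MAC for matmul).
FLOP_MAPPING = {
    torch.ops.aten.mm.default: lambda args, out: 2 * args[0].shape[0] * args[0].shape[1] * args[1].shape[1],
}

class FlopCounterMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.flops = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))    # runs on meta tensors: shapes only
        if func in FLOP_MAPPING:
            self.flops += FLOP_MAPPING[func](args, out)
        return out

with FlopCounterMode() as counter:
    a = torch.rand(64, 128, device='meta')
    b = torch.rand(128, 32, device='meta')
    torch.mm(a, b)
print(counter.flops)    # 2 * 64 * 128 * 32 = 524288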


 def profile_function(target: 'Target') -> Callable:
@@ -127,31 +128,19 @@ def profile_function(target: 'Target') -> Callable:
     Only original `torch.nn.functional` are available.

     Examples:
-        >> input = torch.rand(100, 100, 100, 100, device='meta')
-        >> func = torch.nn.functional.relu
-        >> output, profile = profile_function(func)(input, inplace=False)
-        >> print(f"Profiling function {func},")
-        >> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
-        Profiling function <function relu at 0x7fcdd0258d30>,
-        Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs
+        >>> input = torch.rand(100, 100, 100, 100, device='meta')
+        >>> func = torch.nn.functional.relu
+        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        assert meta_profiler_function.has(target) or meta_profiler_function.has(
-            target.__name__), CALL_FUNCTION_MSG.format(target)
-
-        # call_function has no parameters
-        param_size = 0
-        activation_size = 0
-        result = func(*args, **kwargs)
-        if target not in INPLACE_OPS and not kwargs.get('inplace', False):
-            activation_size += calculate_activation_size(result)
-        if meta_profiler_function.has(target):
-            profiler = meta_profiler_function.get(target)
-        else:
-            profiler = meta_profiler_function.get(target.__name__)
-        flops, macs = profiler(*args, **kwargs)
-        return result, MetaProfile(param_size, activation_size, flops, macs)
+        if kwargs.get('inplace', False):
+            args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
+            kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
+            out = func(*args, **kwargs)
+            return out, (0, 0), (0, 0, 0, 0)
+        out, flop_count, mem_stat = _profile(func, *args, **kwargs)
+        return out, flop_count, mem_stat

     f.__name__ = target.__name__
     func = target
@@ -162,27 +151,13 @@ def profile_method(target: 'Target') -> Callable:
     """
     Wrap a `call_method` node and record the memory cost and FLOPs of the execution.
-
-    Warnings:
-        This is not fully implemented and you may follow the error message to debug.
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        # args[0] is the `self` object for this method call
-        self_obj, *args_tail = args
-
-        # execute the method and return the result
-        assert isinstance(target, str), f'{target} instance is not str.'
-
-        result = getattr(self_obj, target)(*args_tail, **kwargs)
-        assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
-            target, INPLACE_METHOD, NON_INPLACE_METHOD)
-        # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
-        param_size = 0
-        activation_size = 0 if target in INPLACE_METHOD else calculate_activation_size(result)
-        flops = 0
-        macs = 0
-        return result, MetaProfile(param_size, activation_size, flops, macs)
+        # execute the method and return the result
+        assert isinstance(target, str), f'{target} instance is not str.'
+        out, flop_count, mem_stat = _profile(target, *args, **kwargs)
+        return out, flop_count, mem_stat

     return f
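Under the new path a `call_method` node is profiled by handing the method name straight to _profile, which resolves it via getattr on the first argument. A hedged sketch of a call; it assumes profile_method is re-exported by `colossalai.fx.profiler` and that the aten ops behind `matmul` are covered by flop_mapping:

import torch
from colossalai.fx.profiler import profile_method    # assumed re-export

a = torch.rand(64, 128, device='meta', requires_grad=True)    # args[0] is the `self` object
b = torch.rand(128, 32, device='meta')
out, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_method('matmul')(a, b)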
@@ -197,27 +172,19 @@ def profile_module(module: torch.nn.Module) -> Callable:
     Only original `torch.nn` are available.

     Example:
-        >> input = torch.rand(4, 3, 224, 224, device='meta')
-        >> mod = torch.nn.Conv2d(3, 128, 3)
-        >> output, profile = profile_module(mod)(input)
-        >> print(f"Profiling function {mod},")
-        >> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
-        Profiling function Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)),
-        Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs
+        >>> input = torch.rand(4, 3, 224, 224, device='meta')
+        >>> mod = torch.nn.Conv2d(3, 128, 3)
+        >>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
-
-        # only `nn.Module` has parameters
-        param_size = calculate_param_size(module)
-        activation_size = 0
-        result = func(*args, **kwargs)
-        if not getattr(module, 'inplace', False):
-            activation_size += calculate_activation_size(result)
-        profiler = meta_profiler_module.get(type(module))
-        flops, macs = profiler(module, *args, **kwargs)
-        return result, MetaProfile(param_size, activation_size, flops, macs)
+        if getattr(module, 'inplace', False):
+            args = tree_map(lambda x: x.to('meta'), args)
+            kwargs = tree_map(lambda x: x.to('meta'), kwargs)
+            out = func(*args, **kwargs)
+            return out, (out.numel(), out.numel()), (0, 0, 0, 0)
+        out, flop_count, mem_stat = _profile(func, *args, **kwargs)
+        return out, flop_count, mem_stat

     f.__name__ = module.__class__.__name__
     func = module.forward