[fx] provide a stable but not accurate enough version of profiler. (#1547)

* [fx] compute memory stat and flop count for MetaInfoProp.

* [fx] modify node attribute.

* [fx] modify ckpt_chen.

* [fx] fix compatibility.

* [fx] fix import error.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip if torch 1.11.0.

* [fx] recover MetaInfoProp support for PyTorch 1.11.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix import error.
Super Daniel
2022-09-07 11:21:04 +08:00
committed by GitHub
parent 7d49e7b2db
commit 4f59693207
38 changed files with 776 additions and 263 deletions


@@ -1,120 +1,121 @@
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx import Graph
from torch.fx.node import Argument, Target
from torch.fx._compatibility import compatibility
from . import meta_profiler_function, meta_profiler_module
from torch.utils._pytree import tree_map
from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS
from .tensor import MetaTensor
from .opcount import flop_mapping
__all__ = [
'MetaProfile', 'profile_function', 'profile_module', 'profile_method', 'calculate_activation_size',
'calculate_param_size'
]
CALL_FUNCTION_MSG = \
"""
Colossal-AI does not support profiling for {} yet. You can manually patch it with the following code.\n
from colossalai.fx.profiler import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
CALL_MODULE_MSG = \
"""
Colossal-AI does not support profiling for {} yet. You can manually patch it with the following code.\n
from colossalai.fx.profiler import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'expand',
'mean',
]
__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile']
@compatibility(is_backward_compatible=True)
class MetaProfile(NamedTuple):
# MetaProfile is a structure containing pertinent information
# about a node within a torch.fx GraphModule.
param: int
activation: int
flops: int
macs: int
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def calculate_activation_size(activation: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
"""Profile a Callable function with args and kwargs.
Args:
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
target (Callable): A Callable function
args (Any): Positional arguments passed to the target
kwargs (Any): Keyword arguments passed to the target
Returns:
int: The activation size
out (Tuple[Any, ...]): The output of the profiled callable.
flop_count (Tuple[int, ...]): The FLOP counts (fwd_flop, bwd_flop).
mem_stat (Tuple[int, ...]): The memory statistics (fwd_tmp, fwd_out, bwd_tmp, bwd_out).
"""
activation_size = 0
if isinstance(activation, torch.Tensor):
activation_size += activation.numel() * torch.tensor([], dtype=activation.dtype).element_size()
elif isinstance(activation, dict):
value_list = [v for _, v in activation.items()]
activation_size += calculate_activation_size(value_list)
elif isinstance(activation, tuple) or isinstance(activation, list):
for element in activation:
activation_size += calculate_activation_size(element)
return activation_size
flop_count = {
'f': 0,
'l': 0,
'b': 0,
}
temp = {
'f': [],
'l': [],
'b': [],
}
stage = 'f'
def calculate_param_size(mod: torch.nn.Module) -> int:
"""Calculate param size of a node.
class FlopTensor(MetaTensor):
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
def __repr__(self):
if self.grad_fn:
return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"
Returns:
int: The param size
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = FlopTensor(x.to('meta'))
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
def to_meta(x):
return x.to('meta') if isinstance(x, torch.Tensor) else x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# dispatch the aten op as if on CPU, but actually execute it on the meta backend
out = func(*args, **kwargs)
flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
if func not in INPLACE_ATEN:
temp[stage].append(tree_map(to_meta, normalize_tuple(out)))
def wrap(x):
return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)
if target not in WEIRD_OPS:
def wrap(x):
return FlopTensor(
x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
else:
def wrap(x):
return FlopTensor(
x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
if isinstance(target, str):
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
out = getattr(self_obj, target)(*args_tail, **kwargs)
else:
out = target(*args, **kwargs)
if is_autogradable(out) and out.requires_grad:
stage = 'l'
loss = out.sum()
stage = 'b'
loss.backward()
fwd_flop = flop_count['f']
bwd_flop = flop_count['b']
fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0
fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0
bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0
def unwrap(x):
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0)
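# Minimal usage sketch of `_profile` (illustrative; the shapes and the matmul
# target are arbitrary examples):
#
#     a = torch.rand(64, 128, device='meta')
#     b = torch.rand(128, 256, device='meta')
#     out, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = _profile(torch.matmul, a, b)
#     # fwd_flop/bwd_flop are accumulated via `flop_mapping`; the memory entries
#     # are activation sizes in bytes computed by `activation_size`.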
def profile_function(target: 'Target') -> Callable:
@@ -127,31 +128,19 @@ def profile_function(target: 'Target') -> Callable:
Only original `torch.nn.functional` operators are supported.
Examples:
>> input = torch.rand(100, 100, 100, 100, device='meta')
>> func = torch.nn.functional.relu
>> output, profile = profile_function(func)(input, inplace=False)
>> print(f"Profiling function {func},")
>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Profiling function <function relu at 0x7fcdd0258d30>,
Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), CALL_FUNCTION_MSG.format(target)
# call_function has no parameters
param_size = 0
activation_size = 0
result = func(*args, **kwargs)
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
activation_size += calculate_activation_size(result)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
else:
profiler = meta_profiler_function.get(target.__name__)
flops, macs = profiler(*args, **kwargs)
return result, MetaProfile(param_size, activation_size, flops, macs)
if kwargs.get('inplace', False):
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
out = func(*args, **kwargs)
return out, (0, 0), (0, 0, 0, 0)
out, flop_count, mem_stat = _profile(func, *args, **kwargs)
return out, flop_count, mem_stat
f.__name__ = target.__name__
func = target
@@ -162,27 +151,13 @@ def profile_method(target: 'Target') -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
Warnings:
This is not fully implemented and you may follow the error message to debug.
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
result = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
# call_method nodes have no parameters, are mostly (?) in-place, and incur no FLOPs or MACs.
param_size = 0
activation_size = 0 if target in INPLACE_METHOD else calculate_activation_size(result)
flops = 0
macs = 0
return result, MetaProfile(param_size, activation_size, flops, macs)
out, flop_count, mem_stat = _profile(target, *args, **kwargs)
return out, flop_count, mem_stat
return f
@@ -197,27 +172,19 @@ def profile_module(module: torch.nn.Module) -> Callable:
Only original `torch.nn` modules are supported.
Example:
>> input = torch.rand(4, 3, 224, 224, device='meta')
>> mod = torch.nn.Conv2d(3, 128, 3)
>> output, profile = profile_module(mod)(input)
>> print(f"Profiling function {mod},")
>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Profiling function Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)),
Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
# only `nn.Module` has parameters
param_size = calculate_param_size(module)
activation_size = 0
result = func(*args, **kwargs)
if not getattr(module, 'inplace', False):
activation_size += calculate_activation_size(result)
profiler = meta_profiler_module.get(type(module))
flops, macs = profiler(module, *args, **kwargs)
return result, MetaProfile(param_size, activation_size, flops, macs)
if getattr(module, 'inplace', False):
args = tree_map(lambda x: x.to('meta'), args)
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
out = func(*args, **kwargs)
return out, (out.numel(), out.numel()), (0, 0, 0, 0)
out, flop_count, mem_stat = _profile(func, *args, **kwargs)
return out, flop_count, mem_stat
f.__name__ = module.__class__.__name__
func = module.forward
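# End-to-end sketch (illustrative), mirroring the docstring example above:
#
#     mod = torch.nn.Conv2d(3, 128, 3)
#     input = torch.rand(4, 3, 224, 224, device='meta')
#     output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
#     print(f'{fwd_flop} fwd FLOPs, {bwd_flop} bwd FLOPs, {fwd_out} B of output activation')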