mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 22:42:15 +00:00
[fx] add profiler for fx nodes. (#1480)
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] add rules to linearize computation graphs for searching. (#2) * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] fix test and algorithm bugs in activation checkpointing. * [fx] polish ckpt_test. * [fx] add rules to linearize computation graphs for searching. * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] fix inconsistencies. * [fx] fix MetaInfoProp. * [fx] fix MetaInfoProp. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] fix error in tests. * [fx] unfix bug. * [fx] unfix bug.
This commit is contained in:
parent
d39e11dffb
commit
32efe8e740
@ -73,10 +73,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
|
|||||||
y = 0
|
y = 0
|
||||||
prev_idx = 2
|
prev_idx = 2
|
||||||
for (idx, n) in enumerate(gm.graph.nodes):
|
for (idx, n) in enumerate(gm.graph.nodes):
|
||||||
temp += getattr(n, 'activation_size')
|
temp += getattr(n, '__activation__')
|
||||||
y = max(y, temp)
|
y = max(y, temp)
|
||||||
if temp > b and n in ckpt_nodes:
|
if temp > b and n in ckpt_nodes:
|
||||||
x += getattr(n, 'activation_size')
|
x += getattr(n, '__activation__')
|
||||||
temp = 0
|
temp = 0
|
||||||
ckpt_intv.append((prev_idx, idx + 1))
|
ckpt_intv.append((prev_idx, idx + 1))
|
||||||
prev_idx = idx + 1
|
prev_idx = idx + 1
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
|
from operator import add, getitem
|
||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from torch.fx.node import Node, map_aggregate
|
from torch.fx.node import Node, map_aggregate, Argument, Target
|
||||||
from typing import Any, Tuple, NamedTuple, Optional, Dict
|
from typing import Any, Tuple, NamedTuple, Optional, Dict
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from torch.fx._compatibility import compatibility
|
from torch.fx._compatibility import compatibility
|
||||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||||
|
from colossalai.fx.profiler import MetaProfile, profile_function, profile_module, calculate_activation_size, profile_method
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
@ -36,47 +38,11 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
|
|||||||
return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
|
return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
|
||||||
|
|
||||||
|
|
||||||
def _compute_activation_size(node_metadata: any) -> int:
|
|
||||||
"""
|
|
||||||
Compute numel of a node with ``tensor_meta`` attribute.
|
|
||||||
"""
|
|
||||||
node_numel = 0
|
|
||||||
|
|
||||||
if isinstance(node_metadata, TensorMetadata):
|
|
||||||
node_numel += node_metadata.numel * torch.tensor([], dtype=node_metadata.dtype).element_size()
|
|
||||||
elif isinstance(node_metadata, dict):
|
|
||||||
value_list = [v for _, v in node_metadata.items()]
|
|
||||||
node_numel += _compute_activation_size(value_list)
|
|
||||||
else:
|
|
||||||
for element in node_metadata:
|
|
||||||
node_numel += _compute_activation_size(element)
|
|
||||||
|
|
||||||
return node_numel
|
|
||||||
|
|
||||||
|
|
||||||
def _map_aggregate(arg, fn):
|
|
||||||
"""
|
|
||||||
Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
|
|
||||||
"""
|
|
||||||
if isinstance(arg, torch.Size):
|
|
||||||
return fn(arg)
|
|
||||||
if isinstance(arg, tuple):
|
|
||||||
return tuple(map_aggregate(elem, fn) for elem in arg)
|
|
||||||
elif isinstance(arg, list):
|
|
||||||
return immutable_list(map_aggregate(elem, fn) for elem in arg)
|
|
||||||
elif isinstance(arg, dict):
|
|
||||||
return immutable_dict((k, map_aggregate(v, fn)) for k, v in arg.items())
|
|
||||||
elif isinstance(arg, slice):
|
|
||||||
return slice(map_aggregate(arg.start, fn), map_aggregate(arg.stop, fn), map_aggregate(arg.step, fn))
|
|
||||||
else:
|
|
||||||
return fn(arg)
|
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
class MetaInfoProp(torch.fx.Interpreter):
|
class MetaInfoProp(torch.fx.Interpreter):
|
||||||
"""
|
"""
|
||||||
Execute an FX graph Node-by-Node and
|
Execute an FX graph Node-by-Node with meta tensor and
|
||||||
record the shape and type of the result
|
record the shape, FLOPs, MACs and type of the result
|
||||||
into the corresponding node.
|
into the corresponding node.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@ -104,9 +70,32 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
|
||||||
|
"""
|
||||||
|
Add additional check for initial args to ensure all the tensor appears with `device='meta'`
|
||||||
|
"""
|
||||||
|
for elem in args:
|
||||||
|
if isinstance(elem, torch.Tensor):
|
||||||
|
assert elem.is_meta, "Input torch.Tensor are assumed to appear with device='meta'"
|
||||||
|
return super().run(*args, initial_env, enable_io_processing)
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
def run_node(self, n: Node) -> Any:
|
def run_node(self, n: Node) -> Any:
|
||||||
# TODO: We might run_node(n) with meta data, and count FLOPS for each node
|
"""
|
||||||
result = super().run_node(n)
|
Run a specific node ``n`` and return the result.
|
||||||
|
Calls into placeholder, get_attr, call_function,
|
||||||
|
call_method, call_module, or output depending
|
||||||
|
on ``node.op``
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (Node): The Node to execute
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: The result of executing ``n``
|
||||||
|
"""
|
||||||
|
result, profile = super().run_node(n)
|
||||||
|
profile: MetaProfile
|
||||||
|
|
||||||
def extract_tensor_meta(obj):
|
def extract_tensor_meta(obj):
|
||||||
if isinstance(obj, torch.Tensor):
|
if isinstance(obj, torch.Tensor):
|
||||||
@ -114,29 +103,139 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||||||
else:
|
else:
|
||||||
return TensorMetadata(None, None, False, None, 0, False)
|
return TensorMetadata(None, None, False, None, 0, False)
|
||||||
|
|
||||||
meta = _map_aggregate(result, extract_tensor_meta)
|
meta = map_aggregate(result, extract_tensor_meta)
|
||||||
n.meta['tensor_meta'] = meta
|
n.meta['tensor_meta'] = meta
|
||||||
|
|
||||||
total_activation_size = 0
|
# TODO: the attribute node_size should be removed in the future
|
||||||
total_param_size = 0
|
setattr(n, 'node_size', profile.param + profile.activation)
|
||||||
if n.op == 'call_module':
|
setattr(n, '__param__', profile.param)
|
||||||
target_module = n.graph.owning_module.get_submodule(n.target)
|
setattr(n, '__activation__', profile.activation)
|
||||||
if not getattr(target_module, 'inplace', False):
|
setattr(n, '__flops__', profile.flops)
|
||||||
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
|
setattr(n, '__macs__', profile.macs)
|
||||||
for param in target_module.parameters():
|
|
||||||
total_param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
|
|
||||||
elif n.op == 'call_function':
|
|
||||||
if 'inplace' not in n.kwargs:
|
|
||||||
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
|
|
||||||
else:
|
|
||||||
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
|
|
||||||
|
|
||||||
setattr(n, 'node_size', total_activation_size + total_param_size)
|
|
||||||
setattr(n, 'param_size', total_param_size)
|
|
||||||
setattr(n, 'activation_size', total_activation_size)
|
|
||||||
n.meta['type'] = type(result)
|
n.meta['type'] = type(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# Main Node running APIs
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||||
|
"""
|
||||||
|
Execute a ``placeholder`` node. Note that this is stateful:
|
||||||
|
``Interpreter`` maintains an internal iterator over
|
||||||
|
arguments passed to ``run`` and this method returns
|
||||||
|
next() on that iterator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (Target): The call target for this node. See
|
||||||
|
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||||
|
details on semantics
|
||||||
|
args (Tuple): Tuple of positional args for this invocation
|
||||||
|
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result (Any): The argument value that was retrieved
|
||||||
|
profile (MetaProfile): The meta profile of this node
|
||||||
|
"""
|
||||||
|
result = super().placeholder(target, args, kwargs)
|
||||||
|
# A placeholder node only has activation
|
||||||
|
return result, MetaProfile(0, calculate_activation_size(result), 0, 0)
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||||
|
"""
|
||||||
|
Execute a ``get_attr`` node. Will retrieve an attribute
|
||||||
|
value from the ``Module`` hierarchy of ``self.module``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (Target): The call target for this node. See
|
||||||
|
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||||
|
details on semantics
|
||||||
|
args (Tuple): Tuple of positional args for this invocation
|
||||||
|
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||||
|
|
||||||
|
Return:
|
||||||
|
result (Any): The argument value that was retrieved
|
||||||
|
profile (MetaProfile): The meta profile of this node
|
||||||
|
"""
|
||||||
|
# A get_attr node never has parameters, activations, FLOPs, or MACs
|
||||||
|
return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0)
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||||
|
"""
|
||||||
|
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (Target): The call target for this node. See
|
||||||
|
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||||
|
details on semantics
|
||||||
|
args (Tuple): Tuple of positional args for this invocation
|
||||||
|
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||||
|
|
||||||
|
Return
|
||||||
|
result (Any): The argument value that was retrieved
|
||||||
|
profile (MetaProfile): The meta profile of this node
|
||||||
|
"""
|
||||||
|
assert not isinstance(target, str)
|
||||||
|
return profile_function(target)(*args, **kwargs)
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||||
|
"""
|
||||||
|
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (Target): The call target for this node. See
|
||||||
|
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||||
|
details on semantics
|
||||||
|
args (Tuple): Tuple of positional args for this invocation
|
||||||
|
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||||
|
|
||||||
|
Return
|
||||||
|
result (Any): The argument value that was retrieved
|
||||||
|
profile (MetaProfile): The meta profile of this node
|
||||||
|
"""
|
||||||
|
return profile_method(target)(*args, **kwargs)
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||||
|
"""
|
||||||
|
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (Target): The call target for this node. See
|
||||||
|
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||||
|
details on semantics
|
||||||
|
args (Tuple): Tuple of positional args for this invocation
|
||||||
|
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||||
|
|
||||||
|
Return
|
||||||
|
result (Any): The argument value that was retrieved
|
||||||
|
profile (MetaProfile): The meta profile of this node
|
||||||
|
"""
|
||||||
|
# Retrieve executed args and kwargs values from the environment
|
||||||
|
# Execute the method and return the result
|
||||||
|
assert isinstance(target, str)
|
||||||
|
submod = self.fetch_attr(target)
|
||||||
|
return profile_module(submod)(*args, **kwargs)
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
|
||||||
|
"""
|
||||||
|
Execute an ``output`` node. This really just retrieves
|
||||||
|
the value referenced by the ``output`` node and returns it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target (Target): The call target for this node. See
|
||||||
|
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
|
||||||
|
details on semantics
|
||||||
|
args (Tuple): Tuple of positional args for this invocation
|
||||||
|
kwargs (Dict): Dict of keyword arguments for this invocation
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Any: The return value referenced by the output node
|
||||||
|
"""
|
||||||
|
return args[0], MetaProfile(0, 0, 0, 0)
|
||||||
|
|
||||||
def propagate(self, *args):
|
def propagate(self, *args):
|
||||||
"""
|
"""
|
||||||
Run `module` via interpretation and return the result and
|
Run `module` via interpretation and return the result and
|
||||||
|
4
colossalai/fx/profiler/__init__.py
Normal file
4
colossalai/fx/profiler/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .registry import *
|
||||||
|
from .profiler_function import *
|
||||||
|
from .profiler_module import *
|
||||||
|
from .utils import *
|
8
colossalai/fx/profiler/profiler_function/__init__.py
Normal file
8
colossalai/fx/profiler/profiler_function/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from .activation_function import *
|
||||||
|
from .arithmetic import *
|
||||||
|
from .embedding import *
|
||||||
|
from .linear import *
|
||||||
|
from .normalization import *
|
||||||
|
from .pooling import *
|
||||||
|
from .python_ops import *
|
||||||
|
from .torch_ops import *
|
@ -0,0 +1,29 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_function
|
||||||
|
|
||||||
|
# TODO: different activation has different FLOPs count, currently unused.
|
||||||
|
_multiplier = {
|
||||||
|
torch.nn.functional.relu: 1,
|
||||||
|
torch.nn.functional.prelu: 4,
|
||||||
|
torch.nn.functional.sigmoid: 4,
|
||||||
|
torch.nn.functional.tanh: 5,
|
||||||
|
torch.nn.functional.leaky_relu: 3,
|
||||||
|
torch.nn.functional.elu: 4,
|
||||||
|
torch.nn.functional.relu6: 2,
|
||||||
|
torch.nn.functional.gelu: 9,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.leaky_relu)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.elu)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.gelu)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.relu6)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.prelu)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.relu)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.sigmoid)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.tanh)
|
||||||
|
def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
|
||||||
|
flops = input.numel()
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
83
colossalai/fx/profiler/profiler_function/arithmetic.py
Normal file
83
colossalai/fx/profiler/profiler_function/arithmetic.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
from typing import Any, Optional, Tuple, Union
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_function
|
||||||
|
|
||||||
|
|
||||||
|
def _prod(dims):
|
||||||
|
p = 1
|
||||||
|
for v in dims:
|
||||||
|
p *= v
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def _elementwise_flops_compute(input, other):
|
||||||
|
# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
|
||||||
|
if not torch.is_tensor(input):
|
||||||
|
if torch.is_tensor(other):
|
||||||
|
return _prod(other.shape), 0
|
||||||
|
else:
|
||||||
|
return 1, 0
|
||||||
|
elif not torch.is_tensor(other):
|
||||||
|
return _prod(input.shape), 0
|
||||||
|
else:
|
||||||
|
dim_input = len(input.shape)
|
||||||
|
dim_other = len(other.shape)
|
||||||
|
max_dim = max(dim_input, dim_other)
|
||||||
|
|
||||||
|
final_shape = []
|
||||||
|
for i in range(max_dim):
|
||||||
|
in_i = input.shape[i] if i < dim_input else 1
|
||||||
|
ot_i = other.shape[i] if i < dim_other else 1
|
||||||
|
if in_i > ot_i:
|
||||||
|
final_shape.append(in_i)
|
||||||
|
else:
|
||||||
|
final_shape.append(ot_i)
|
||||||
|
flops = _prod(final_shape)
|
||||||
|
return flops, 0
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.add)
|
||||||
|
@meta_profiler_function.register('add') # for built-in op +
|
||||||
|
@meta_profiler_function.register('iadd') # for built-in op +=
|
||||||
|
@meta_profiler_function.register('sub') # for built-in op -
|
||||||
|
@meta_profiler_function.register('isub') # for built-in op -=
|
||||||
|
@meta_profiler_function.register('mul') # for built-in op *
|
||||||
|
@meta_profiler_function.register('imul') # for built-in op *=
|
||||||
|
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
|
||||||
|
return _elementwise_flops_compute(input, other)
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.abs)
|
||||||
|
def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
|
||||||
|
flops = input.numel()
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.matmul)
|
||||||
|
@meta_profiler_function.register('matmul') # for built-in op @
|
||||||
|
@meta_profiler_function.register(torch.Tensor.matmul)
|
||||||
|
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
|
||||||
|
macs = _prod(input.shape) * other.shape[-1]
|
||||||
|
flops = 2 * macs
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.bmm)
|
||||||
|
def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
|
||||||
|
macs = _prod(input.shape) * other.shape[-1]
|
||||||
|
flops = 2 * macs
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.var_mean)
|
||||||
|
def torch_var_mean(input: torch.Tensor,
|
||||||
|
dim: Union[int, Tuple[int, ...]],
|
||||||
|
unbiased: Optional[bool] = True,
|
||||||
|
keepdim: Optional[bool] = False,
|
||||||
|
*,
|
||||||
|
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
|
||||||
|
assert out is None, 'saving to out is not supported yet'
|
||||||
|
flops = input.numel() * 3
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
19
colossalai/fx/profiler/profiler_function/embedding.py
Normal file
19
colossalai/fx/profiler/profiler_function/embedding.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
from ..registry import meta_profiler_function
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.embedding)
|
||||||
|
def torch_nn_functional_embedding(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
padding_idx: Optional[int] = None,
|
||||||
|
max_norm: Optional[float] = None,
|
||||||
|
norm_type: float = 2.0,
|
||||||
|
scale_grad_by_freq: bool = False,
|
||||||
|
sparse: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# F.embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
|
||||||
|
flops = 0
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
13
colossalai/fx/profiler/profiler_function/linear.py
Normal file
13
colossalai/fx/profiler/profiler_function/linear.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_function
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.linear)
|
||||||
|
def torch_nn_linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> Tuple[int, int]:
|
||||||
|
out_features = weight.shape[0]
|
||||||
|
macs = torch.numel(input) * out_features
|
||||||
|
flops = 2 * macs
|
||||||
|
if bias is not None:
|
||||||
|
flops += bias.numel()
|
||||||
|
return flops, macs
|
66
colossalai/fx/profiler/profiler_function/normalization.py
Normal file
66
colossalai/fx/profiler/profiler_function/normalization.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from typing import List, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_function
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.instance_norm)
|
||||||
|
def torch_nn_func_instancenorm(
|
||||||
|
input: torch.Tensor,
|
||||||
|
running_mean: Optional[torch.Tensor] = None,
|
||||||
|
running_var: Optional[torch.Tensor] = None,
|
||||||
|
weight: Optional[torch.Tensor] = None,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
use_input_stats: bool = True,
|
||||||
|
momentum: float = 0.1,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
):
|
||||||
|
has_affine = weight is not None
|
||||||
|
flops = input.numel() * (5 if has_affine else 4)
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.group_norm)
|
||||||
|
def torch_nn_func_groupnorm(input: torch.Tensor,
|
||||||
|
num_groups: int,
|
||||||
|
weight: Optional[torch.Tensor] = None,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-5) -> Tuple[int, int]:
|
||||||
|
has_affine = weight is not None
|
||||||
|
flops = input.numel() * (5 if has_affine else 4)
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.layer_norm)
|
||||||
|
def torch_nn_func_layernorm(
|
||||||
|
input: torch.Tensor,
|
||||||
|
normalized_shape: List[int],
|
||||||
|
weight: Optional[torch.Tensor] = None,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
has_affine = weight is not None
|
||||||
|
flops = input.numel() * (5 if has_affine else 4)
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.batch_norm)
|
||||||
|
def torch_nn_func_batchnorm(
|
||||||
|
input: torch.Tensor,
|
||||||
|
running_mean: Optional[torch.Tensor],
|
||||||
|
running_var: Optional[torch.Tensor],
|
||||||
|
weight: Optional[torch.Tensor] = None,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
training: bool = False,
|
||||||
|
momentum: float = 0.1,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
has_affine = weight is not None
|
||||||
|
if training:
|
||||||
|
flops = input.numel() * (2 if has_affine else 1)
|
||||||
|
else:
|
||||||
|
flops = input.numel() * (5 if has_affine else 4)
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
22
colossalai/fx/profiler/profiler_function/pooling.py
Normal file
22
colossalai/fx/profiler/profiler_function/pooling.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from typing import Tuple, Union
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_function
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.avg_pool1d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.avg_pool2d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.avg_pool3d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.max_pool1d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.max_pool2d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.max_pool3d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool1d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool2d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool3d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool1d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool2d)
|
||||||
|
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool3d)
|
||||||
|
def torch_nn_func_pooling(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]:
|
||||||
|
# all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217)
|
||||||
|
flops = input.numel()
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
12
colossalai/fx/profiler/profiler_function/python_ops.py
Normal file
12
colossalai/fx/profiler/profiler_function/python_ops.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import operator
|
||||||
|
from typing import Any, Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_function
|
||||||
|
from colossalai.fx.proxy import ColoProxy
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(operator.getitem)
|
||||||
|
def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
|
||||||
|
flops = 0
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
64
colossalai/fx/profiler/profiler_function/torch_ops.py
Normal file
64
colossalai/fx/profiler/profiler_function/torch_ops.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from typing import Any, Optional, Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_function
|
||||||
|
|
||||||
|
|
||||||
|
def _prod(dims):
|
||||||
|
p = 1
|
||||||
|
for v in dims:
|
||||||
|
p *= v
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.arange)
|
||||||
|
@meta_profiler_function.register(torch.finfo)
|
||||||
|
@meta_profiler_function.register(torch.permute)
|
||||||
|
@meta_profiler_function.register(torch.Tensor.permute)
|
||||||
|
@meta_profiler_function.register(torch.Tensor.repeat)
|
||||||
|
@meta_profiler_function.register(torch.index_select)
|
||||||
|
@meta_profiler_function.register(torch.Tensor.index_select)
|
||||||
|
@meta_profiler_function.register(torch.squeeze)
|
||||||
|
@meta_profiler_function.register(torch.Tensor.squeeze)
|
||||||
|
@meta_profiler_function.register(torch.unsqueeze)
|
||||||
|
@meta_profiler_function.register(torch.Tensor.unsqueeze)
|
||||||
|
@meta_profiler_function.register(torch.cat)
|
||||||
|
@meta_profiler_function.register(torch.concat)
|
||||||
|
@meta_profiler_function.register(torch.repeat_interleave)
|
||||||
|
@meta_profiler_function.register(torch.Tensor.repeat_interleave)
|
||||||
|
@meta_profiler_function.register(torch.flatten)
|
||||||
|
@meta_profiler_function.register(torch.Tensor.flatten)
|
||||||
|
@meta_profiler_function.register(torch.roll)
|
||||||
|
@meta_profiler_function.register(torch.full)
|
||||||
|
@meta_profiler_function.register(torch.Tensor.cpu)
|
||||||
|
@meta_profiler_function.register(torch.Tensor.cuda)
|
||||||
|
def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
|
||||||
|
flops = 0
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.where)
|
||||||
|
def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
|
||||||
|
# torch.where returns the broadcasted tensor of condition, x, and y,
|
||||||
|
# so hack it by using addition
|
||||||
|
flops = condition.numel()
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_function.register(torch.max)
|
||||||
|
def torch_max(input: torch.Tensor,
|
||||||
|
dim: int = None,
|
||||||
|
keepdim: bool = False,
|
||||||
|
*,
|
||||||
|
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
|
||||||
|
macs = 0
|
||||||
|
assert out is None, 'assigning value to out is not supported yet'
|
||||||
|
if dim is not None:
|
||||||
|
shape = list(input.shape)
|
||||||
|
shape.pop(int(dim))
|
||||||
|
flops = _prod(shape), macs
|
||||||
|
return flops, macs
|
||||||
|
else:
|
||||||
|
flops = input.numel()
|
||||||
|
return flops, macs
|
7
colossalai/fx/profiler/profiler_module/__init__.py
Normal file
7
colossalai/fx/profiler/profiler_module/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from .activation_function import *
|
||||||
|
from .convolution import *
|
||||||
|
from .embedding import *
|
||||||
|
from .linear import *
|
||||||
|
from .normalization import *
|
||||||
|
from .pooling import *
|
||||||
|
from .rnn import *
|
@ -0,0 +1,29 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_module
|
||||||
|
|
||||||
|
# TODO: different activation has different FLOPs count, currently unused.
|
||||||
|
_multiplier = {
|
||||||
|
torch.nn.ReLU: 1,
|
||||||
|
torch.nn.PReLU: 4,
|
||||||
|
torch.nn.Sigmoid: 4,
|
||||||
|
torch.nn.Tanh: 5,
|
||||||
|
torch.nn.LeakyReLU: 3,
|
||||||
|
torch.nn.ELU: 4,
|
||||||
|
torch.nn.ReLU6: 2,
|
||||||
|
torch.nn.GELU: 9,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.ELU)
|
||||||
|
@meta_profiler_module.register(torch.nn.LeakyReLU)
|
||||||
|
@meta_profiler_module.register(torch.nn.ReLU)
|
||||||
|
@meta_profiler_module.register(torch.nn.GELU)
|
||||||
|
@meta_profiler_module.register(torch.nn.Sigmoid)
|
||||||
|
@meta_profiler_module.register(torch.nn.Tanh)
|
||||||
|
@meta_profiler_module.register(torch.nn.ReLU6)
|
||||||
|
@meta_profiler_module.register(torch.nn.PReLU)
|
||||||
|
def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
flops = input.numel()
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
157
colossalai/fx/profiler/profiler_module/convolution.py
Normal file
157
colossalai/fx/profiler/profiler_module/convolution.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
import math
|
||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_module
|
||||||
|
|
||||||
|
|
||||||
|
def _prod(dims):
|
||||||
|
p = 1
|
||||||
|
for v in dims:
|
||||||
|
p *= v
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.Conv1d)
|
||||||
|
def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
# the output shape is calculated using the formula stated
|
||||||
|
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
||||||
|
c_in, l_in = input.shape[-2:]
|
||||||
|
c_out = self.out_channels
|
||||||
|
l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
|
||||||
|
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
|
||||||
|
result_shape = input.shape[:-2] + (
|
||||||
|
c_out,
|
||||||
|
l_out,
|
||||||
|
)
|
||||||
|
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
|
||||||
|
num_elem = _prod(result_shape)
|
||||||
|
macs = macs_per_elem * num_elem
|
||||||
|
flops = 2 * macs
|
||||||
|
if self.bias is not None:
|
||||||
|
flops += num_elem
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.Conv2d)
|
||||||
|
def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
# the output shape is calculated using the formula stated
|
||||||
|
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
||||||
|
c_in, h_in, w_in = input.shape[-3:]
|
||||||
|
c_out = self.out_channels
|
||||||
|
h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
|
||||||
|
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
|
||||||
|
w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
|
||||||
|
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
|
||||||
|
result_shape = input.shape[:-3] + (
|
||||||
|
c_out,
|
||||||
|
h_out,
|
||||||
|
w_out,
|
||||||
|
)
|
||||||
|
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
|
||||||
|
num_elem = _prod(result_shape)
|
||||||
|
macs = macs_per_elem * num_elem
|
||||||
|
flops = 2 * macs
|
||||||
|
if self.bias is not None:
|
||||||
|
flops += num_elem
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.Conv3d)
|
||||||
|
def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
# the output shape is calculated using the formula stated
|
||||||
|
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
|
||||||
|
c_in, d_in, h_in, w_in = input.shape[-4:]
|
||||||
|
c_out = self.out_channels
|
||||||
|
d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
|
||||||
|
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
|
||||||
|
h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
|
||||||
|
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
|
||||||
|
w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
|
||||||
|
(self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
|
||||||
|
result_shape = input.shape[:-4] + (
|
||||||
|
c_out,
|
||||||
|
d_out,
|
||||||
|
h_out,
|
||||||
|
w_out,
|
||||||
|
)
|
||||||
|
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
|
||||||
|
num_elem = _prod(result_shape)
|
||||||
|
macs = macs_per_elem * num_elem
|
||||||
|
flops = 2 * macs
|
||||||
|
if self.bias is not None:
|
||||||
|
flops += num_elem
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.ConvTranspose1d)
|
||||||
|
def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
# the output shape is calculated using the formula stated
|
||||||
|
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
|
||||||
|
c_in, l_in = input.shape[-2:]
|
||||||
|
c_out = self.out_channels
|
||||||
|
l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
|
||||||
|
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
|
||||||
|
result_shape = input.shape[:-2] + (
|
||||||
|
c_out,
|
||||||
|
l_out,
|
||||||
|
)
|
||||||
|
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
|
||||||
|
num_elem = _prod(
|
||||||
|
input.shape
|
||||||
|
) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
|
||||||
|
macs = macs_per_elem * num_elem
|
||||||
|
flops = 2 * macs
|
||||||
|
if self.bias is not None:
|
||||||
|
flops += _prod(result_shape)
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.ConvTranspose2d)
|
||||||
|
def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
# the output shape is calculated using the formula stated
|
||||||
|
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
|
||||||
|
c_in, h_in, w_in = input.shape[-3:]
|
||||||
|
c_out = self.out_channels
|
||||||
|
h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
|
||||||
|
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
|
||||||
|
w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
|
||||||
|
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
|
||||||
|
result_shape = input.shape[:-3] + (
|
||||||
|
c_out,
|
||||||
|
h_out,
|
||||||
|
w_out,
|
||||||
|
)
|
||||||
|
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
|
||||||
|
num_elem = _prod(input.shape)
|
||||||
|
macs = macs_per_elem * num_elem
|
||||||
|
flops = 2 * macs
|
||||||
|
if self.bias is not None:
|
||||||
|
flops += _prod(result_shape)
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.ConvTranspose3d)
|
||||||
|
def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
# the output shape is calculated using the formula stated
|
||||||
|
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
|
||||||
|
c_in, d_in, h_in, w_in = input.shape[-4:]
|
||||||
|
c_out = self.out_channels
|
||||||
|
d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
|
||||||
|
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
|
||||||
|
h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
|
||||||
|
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
|
||||||
|
w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
|
||||||
|
(self.kernel_size[2] - 1) + self.output_padding[2] + 1)
|
||||||
|
result_shape = input.shape[:-4] + (
|
||||||
|
c_out,
|
||||||
|
d_out,
|
||||||
|
h_out,
|
||||||
|
w_out,
|
||||||
|
)
|
||||||
|
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
|
||||||
|
num_elem = _prod(input.shape)
|
||||||
|
macs = macs_per_elem * num_elem
|
||||||
|
flops = 2 * macs
|
||||||
|
if self.bias is not None:
|
||||||
|
flops += _prod(result_shape)
|
||||||
|
return flops, macs
|
11
colossalai/fx/profiler/profiler_module/embedding.py
Normal file
11
colossalai/fx/profiler/profiler_module/embedding.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_module
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.Embedding)
|
||||||
|
def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
# nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
|
||||||
|
flops = 0
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
13
colossalai/fx/profiler/profiler_module/linear.py
Normal file
13
colossalai/fx/profiler/profiler_module/linear.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_module
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.Linear)
|
||||||
|
def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
out_features = self.weight.shape[0]
|
||||||
|
macs = torch.numel(input) * out_features
|
||||||
|
flops = 2 * macs
|
||||||
|
if self.bias is not None:
|
||||||
|
flops += self.bias.numel()
|
||||||
|
return flops, macs
|
33
colossalai/fx/profiler/profiler_module/normalization.py
Normal file
33
colossalai/fx/profiler/profiler_module/normalization.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from typing import Tuple, Union
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_module
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.InstanceNorm1d)
|
||||||
|
@meta_profiler_module.register(torch.nn.InstanceNorm2d)
|
||||||
|
@meta_profiler_module.register(torch.nn.InstanceNorm3d)
|
||||||
|
@meta_profiler_module.register(torch.nn.LayerNorm)
|
||||||
|
@meta_profiler_module.register(torch.nn.GroupNorm)
|
||||||
|
@meta_profiler_module.register(torch.nn.BatchNorm1d)
|
||||||
|
@meta_profiler_module.register(torch.nn.BatchNorm2d)
|
||||||
|
@meta_profiler_module.register(torch.nn.BatchNorm3d)
|
||||||
|
def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
|
||||||
|
torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
# adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
|
||||||
|
has_affine = self.weight is not None
|
||||||
|
if self.training:
|
||||||
|
flops = input.numel() * (2 if has_affine else 1)
|
||||||
|
else:
|
||||||
|
flops = input.numel() * (5 if has_affine else 4)
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import apex
|
||||||
|
meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
|
||||||
|
meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
|
||||||
|
meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
|
||||||
|
meta_profiler_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize)
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
pass
|
22
colossalai/fx/profiler/profiler_module/pooling.py
Normal file
22
colossalai/fx/profiler/profiler_module/pooling.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_module
|
||||||
|
|
||||||
|
|
||||||
|
@meta_profiler_module.register(torch.nn.AvgPool1d)
|
||||||
|
@meta_profiler_module.register(torch.nn.AvgPool2d)
|
||||||
|
@meta_profiler_module.register(torch.nn.AvgPool3d)
|
||||||
|
@meta_profiler_module.register(torch.nn.MaxPool1d)
|
||||||
|
@meta_profiler_module.register(torch.nn.MaxPool2d)
|
||||||
|
@meta_profiler_module.register(torch.nn.MaxPool3d)
|
||||||
|
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool1d)
|
||||||
|
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool1d)
|
||||||
|
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool2d)
|
||||||
|
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool2d)
|
||||||
|
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool3d)
|
||||||
|
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool3d)
|
||||||
|
def torch_nn_pooling(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
# all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217)
|
||||||
|
flops = input.numel()
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
13
colossalai/fx/profiler/profiler_module/rnn.py
Normal file
13
colossalai/fx/profiler/profiler_module/rnn.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import torch
|
||||||
|
from ..registry import meta_profiler_module
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: calculate rnn FLOPs
|
||||||
|
@meta_profiler_module.register(torch.nn.GRU)
|
||||||
|
@meta_profiler_module.register(torch.nn.RNN)
|
||||||
|
def torch_nn_rnn(self: torch.nn.Module, input: torch.Tensor, hx: torch.Tensor) -> Tuple[int, int]:
|
||||||
|
raise NotImplementedError
|
||||||
|
flops = 0
|
||||||
|
macs = 0
|
||||||
|
return flops, macs
|
25
colossalai/fx/profiler/registry.py
Normal file
25
colossalai/fx/profiler/registry.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
class ProfilerRegistry:
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
self.store = {}
|
||||||
|
|
||||||
|
def register(self, source):
|
||||||
|
|
||||||
|
def wrapper(func):
|
||||||
|
self.store[source] = func
|
||||||
|
return func
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def get(self, source):
|
||||||
|
assert source in self.store
|
||||||
|
target = self.store[source]
|
||||||
|
return target
|
||||||
|
|
||||||
|
def has(self, source):
|
||||||
|
return source in self.store
|
||||||
|
|
||||||
|
|
||||||
|
meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
|
||||||
|
meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')
|
180
colossalai/fx/profiler/utils.py
Normal file
180
colossalai/fx/profiler/utils.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
from functools import partial
|
||||||
|
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
|
||||||
|
from typing import Callable, NamedTuple, Any, Dict, Tuple
|
||||||
|
import torch
|
||||||
|
from torch.fx.node import Argument, Target
|
||||||
|
from torch.fx._compatibility import compatibility
|
||||||
|
from colossalai.fx.tracer.meta_patch import meta_patched_function, meta_patched_module
|
||||||
|
from . import meta_profiler_function, meta_profiler_module
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'MetaProfile', 'profile_function', 'profile_module', 'profile_method', 'calculate_activation_size',
|
||||||
|
'calculate_param_size'
|
||||||
|
]
|
||||||
|
|
||||||
|
# TODO fill out the inplace ops
|
||||||
|
INPLACE_OPS = [
|
||||||
|
add,
|
||||||
|
sub,
|
||||||
|
mul,
|
||||||
|
floordiv,
|
||||||
|
neg,
|
||||||
|
pos,
|
||||||
|
getitem,
|
||||||
|
setitem,
|
||||||
|
torch.Tensor.cpu,
|
||||||
|
]
|
||||||
|
|
||||||
|
# TODO check that call_methods are indeed inplace
|
||||||
|
INPLACE_METHOD = [
|
||||||
|
'transpose',
|
||||||
|
'permute',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
class MetaProfile(NamedTuple):
|
||||||
|
# MetaProfile is a structure containing pertinent information
|
||||||
|
# about a node within a torch.fx GraphModule.
|
||||||
|
|
||||||
|
param: int
|
||||||
|
activation: int
|
||||||
|
flops: int
|
||||||
|
macs: int
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_activation_size(activation: any) -> int:
|
||||||
|
"""
|
||||||
|
Calculate activation size of a node.
|
||||||
|
"""
|
||||||
|
activation_size = 0
|
||||||
|
if isinstance(activation, torch.Tensor):
|
||||||
|
activation_size += activation.numel() * torch.tensor([], dtype=activation.dtype).element_size()
|
||||||
|
elif isinstance(activation, dict):
|
||||||
|
value_list = [v for _, v in activation.items()]
|
||||||
|
activation_size += calculate_activation_size(value_list)
|
||||||
|
else:
|
||||||
|
for element in activation:
|
||||||
|
activation_size += calculate_activation_size(element)
|
||||||
|
return activation_size
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_param_size(mod: torch.nn.Module) -> int:
|
||||||
|
"""
|
||||||
|
Calculate param size of a node.
|
||||||
|
"""
|
||||||
|
param_size = 0
|
||||||
|
for param in mod.parameters():
|
||||||
|
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
|
||||||
|
return param_size
|
||||||
|
|
||||||
|
|
||||||
|
def profile_function(target: 'Target') -> Callable:
|
||||||
|
"""
|
||||||
|
Wrap a `call_function` node or `torch.nn.functional` in order to
|
||||||
|
record the memory cost and FLOPs of the execution.
|
||||||
|
|
||||||
|
Warnings:
|
||||||
|
You may only use tensors with `device=meta` for this wrapped function.
|
||||||
|
Only original `torch.nn.functional` are available.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
input = torch.rand(100, 100, 100, 100, device='meta')
|
||||||
|
func = torch.nn.functional.relu
|
||||||
|
output, profile = profile_function(func)(input, inplace=False)
|
||||||
|
print(f"Profiling function {func},")
|
||||||
|
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||||
|
assert meta_profiler_function.has(target) or meta_profiler_function.has(
|
||||||
|
target.__name__), f"Colossal-AI hasn't supported profiling for {target}, you might manually patch it."
|
||||||
|
|
||||||
|
# call_function has no parameters
|
||||||
|
param_size = 0
|
||||||
|
activation_size = 0
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
|
||||||
|
activation_size += calculate_activation_size(result)
|
||||||
|
if meta_profiler_function.has(target):
|
||||||
|
profiler = meta_profiler_function.get(target)
|
||||||
|
else:
|
||||||
|
profiler = meta_profiler_function.get(target.__name__)
|
||||||
|
flops, macs = profiler(*args, **kwargs)
|
||||||
|
return result, MetaProfile(param_size, activation_size, flops, macs)
|
||||||
|
|
||||||
|
f.__name__ = target.__name__
|
||||||
|
# fetch patched function
|
||||||
|
if meta_patched_function.has(target):
|
||||||
|
func = meta_patched_function.get(target)
|
||||||
|
elif meta_patched_function.has(target.__name__):
|
||||||
|
func = meta_patched_function.get(target.__name__)
|
||||||
|
else:
|
||||||
|
func = target
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def profile_method(target: 'Target') -> Callable:
|
||||||
|
"""
|
||||||
|
Wrap a `call_method` node
|
||||||
|
record the memory cost and FLOPs of the execution.
|
||||||
|
|
||||||
|
Warnings:
|
||||||
|
This is not fully implemented and you may follow the error message to debug.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||||
|
# args[0] is the `self` object for this method call
|
||||||
|
self_obj, *args_tail = args
|
||||||
|
|
||||||
|
# Execute the method and return the result
|
||||||
|
assert isinstance(target, str), f'{target} instance is not str.'
|
||||||
|
result = getattr(self_obj, target)(*args_tail, **kwargs)
|
||||||
|
assert target in INPLACE_METHOD, f'Please check {target} is an inplace method. If so, add target to INPLACE_METHOD={INPLACE_METHOD}.'
|
||||||
|
|
||||||
|
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
|
||||||
|
param_size = 0
|
||||||
|
activation_size = 0
|
||||||
|
flops = 0
|
||||||
|
macs = 0
|
||||||
|
return result, MetaProfile(param_size, activation_size, flops, macs)
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def profile_module(module: torch.nn.Module) -> Callable:
|
||||||
|
"""
|
||||||
|
Wrap a `call_module` node or `torch.nn` in order to
|
||||||
|
record the memory cost and FLOPs of the execution.
|
||||||
|
|
||||||
|
Warnings:
|
||||||
|
You may only use tensors with `device=meta` for this wrapped function.
|
||||||
|
Only original `torch.nn` are available.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
input = torch.rand(4, 3, 224, 224, device='meta')
|
||||||
|
mod = torch.nn.Conv2d(3, 128, 3)
|
||||||
|
output, profile = profile_module(mod)(input)
|
||||||
|
print(f"Profiling function {mod},")
|
||||||
|
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||||
|
assert meta_profiler_module.has(
|
||||||
|
type(module)), f"Colossal-AI hasn't supported profiling for {module}, you might manually patch it."
|
||||||
|
param_size = calculate_param_size(module)
|
||||||
|
activation_size = 0
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
if not getattr(module, 'inplace', False):
|
||||||
|
activation_size += calculate_activation_size(result)
|
||||||
|
profiler = meta_profiler_module.get(type(module))
|
||||||
|
flops, macs = profiler(module, *args, **kwargs)
|
||||||
|
return result, MetaProfile(param_size, activation_size, flops, macs)
|
||||||
|
|
||||||
|
f.__name__ = module.__class__.__name__
|
||||||
|
# fetch patched module
|
||||||
|
if meta_patched_module.has(type(module)):
|
||||||
|
func = partial(meta_patched_module.get(type(module)), module)
|
||||||
|
else:
|
||||||
|
func = module.forward
|
||||||
|
return f
|
@ -68,7 +68,7 @@ def _run_ckpt_solver(rank):
|
|||||||
|
|
||||||
tracer = ColoTracer(trace_act_ckpt=False)
|
tracer = ColoTracer(trace_act_ckpt=False)
|
||||||
|
|
||||||
data = torch.rand(2, 3, 32, 32)
|
data = torch.rand(2, 3, 32, 32, device='meta')
|
||||||
for solver in SOLVERS:
|
for solver in SOLVERS:
|
||||||
for model_cls in MODEL_LIST:
|
for model_cls in MODEL_LIST:
|
||||||
m = model_cls(num_classes=5)
|
m = model_cls(num_classes=5)
|
||||||
@ -98,7 +98,7 @@ def _run_ckpt_solver_torch11(rank):
|
|||||||
|
|
||||||
tracer = ColoTracer(trace_act_ckpt=False)
|
tracer = ColoTracer(trace_act_ckpt=False)
|
||||||
|
|
||||||
data = torch.rand(2, 3, 32, 32)
|
data = torch.rand(2, 3, 32, 32, device='meta')
|
||||||
for solver in SOLVERS:
|
for solver in SOLVERS:
|
||||||
for model_cls in MODEL_LIST:
|
for model_cls in MODEL_LIST:
|
||||||
m = model_cls(num_classes=5)
|
m = model_cls(num_classes=5)
|
||||||
|
@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
|
|||||||
|
|
||||||
def test_comm_size_compute():
|
def test_comm_size_compute():
|
||||||
model = MLP(MODEL_DIM)
|
model = MLP(MODEL_DIM)
|
||||||
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM)
|
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
|
||||||
gm = symbolic_trace(model)
|
gm = symbolic_trace(model)
|
||||||
MetaInfoProp(gm).run(input_sample)
|
MetaInfoProp(gm).run(input_sample)
|
||||||
annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
|
annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
|
||||||
|
@ -20,17 +20,20 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
|
|||||||
|
|
||||||
def test_meta_info_prop():
|
def test_meta_info_prop():
|
||||||
model = torch.nn.Linear(DIM_IN, DIM_OUT)
|
model = torch.nn.Linear(DIM_IN, DIM_OUT)
|
||||||
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
|
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
|
||||||
orig_output = model(input_sample)
|
orig_output = model(input_sample)
|
||||||
gm = symbolic_trace(model)
|
gm = symbolic_trace(model)
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
assert not hasattr(node,
|
assert not hasattr(node,
|
||||||
'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure'
|
'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure'
|
||||||
assert not hasattr(node,
|
assert not hasattr(node,
|
||||||
'param_size'), 'The attribute Node.param_size should not exist before MetaInfoProp procedure'
|
'__param__'), 'The attribute Node.__param__ should not exist before MetaInfoProp procedure'
|
||||||
assert not hasattr(
|
assert not hasattr(
|
||||||
node,
|
node, '__activation__'), 'The attribute Node.__activation__ should not exist before MetaInfoProp procedure'
|
||||||
'activation_size'), 'The attribute Node.activation_size should not exist before MetaInfoProp procedure'
|
assert not hasattr(node,
|
||||||
|
'__flops__'), 'The attribute Node.__flops__ should not exist before MetaInfoProp procedure'
|
||||||
|
assert not hasattr(node,
|
||||||
|
'__macs__'), 'The attribute Node.__macs__ should not exist before MetaInfoProp procedure'
|
||||||
MetaInfoProp(gm).run(input_sample)
|
MetaInfoProp(gm).run(input_sample)
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
if node.op == 'placeholder':
|
if node.op == 'placeholder':
|
||||||
@ -38,9 +41,11 @@ def test_meta_info_prop():
|
|||||||
if node.op == 'output':
|
if node.op == 'output':
|
||||||
meta_check(node.meta['tensor_meta'], orig_output)
|
meta_check(node.meta['tensor_meta'], orig_output)
|
||||||
assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure'
|
assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure'
|
||||||
assert hasattr(node, 'param_size'), 'The attribute Node.param_size should exist after MetaInfoProp procedure'
|
assert hasattr(node, '__param__'), 'The attribute Node.__param__ should exist after MetaInfoProp procedure'
|
||||||
assert hasattr(
|
assert hasattr(node,
|
||||||
node, 'activation_size'), 'The attribute Node.activation_size should exist after MetaInfoProp procedure'
|
'__activation__'), 'The attribute Node.__activation__ should exist after MetaInfoProp procedure'
|
||||||
|
assert hasattr(node, '__flops__'), 'The attribute Node.__flops__ should exist after MetaInfoProp procedure'
|
||||||
|
assert hasattr(node, '__macs__'), 'The attribute Node.__macs__ should exist after MetaInfoProp procedure'
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user