mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 21:09:18 +00:00)

[fx/tuning] tune performance on rotor with meta info. (#1599)
colossalai/fx/profiler/__init__.py
@@ -7,4 +7,4 @@ else:
     from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
 
 from .dataflow import GraphInfo
-from .memory import parameter_size, activation_size
+from .memory import parameter_size, activation_size, is_inplace
colossalai/fx/profiler/dataflow.py
@@ -1,16 +1,17 @@
 from dataclasses import dataclass
 from enum import Enum
-from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
-from .memory import activation_size
+from .memory import activation_size, is_inplace
+from . import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from .memory import NORMALIZATION_ATEN, CLONE_ATEN
 
 
 class Phase(Enum):
     FORWARD = 0
-    LOSS = 1
-    BACKWARD = 2
-    PLACEHOLDER = 3
+    BACKWARD = 1
+    PLACEHOLDER = 2
 
 
 @dataclass
@@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     def _peak_memory(deps: Dict[Node, int]):
         peak_mem = 0
         for k, v in deps.items():
-            if v > 0:
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
                 peak_mem += activation_size(k.meta['out'])
+            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
+                peak_mem -= activation_size(k.meta['out'])
         return peak_mem
 
     # deps is used to track all the memory dependencies of the graph.
@@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
 
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
+        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
             # A forward tensor who is marked `save` but is not
             # an input to `loss` should be saved during forward.
             # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
@@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
             graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 # liveness analysis is only used in backward
                 deps[n] = len(n.users)
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
-                for input_n in n.all_input_nodes:
-                    if input_n in deps:
-                        deps[input_n] -= 1
             else:
                 # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
                 graph_info.bwd_mem_out += activation_size(n.meta['out'])
+            for input_n in n.all_input_nodes:
+                if input_n in deps:
+                    deps[input_n] -= 1
+                    if deps[input_n] <= 0:
+                        deps[input_n] = float('-inf')
     return graph_info
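The hunk above reworks the liveness bookkeeping: every backward node now decrements the dependency counter of each of its inputs, and a counter that reaches zero is pinned at float('-inf') so the freed activation is subtracted exactly once in _peak_memory. A standalone sketch of the same counting idea, using hypothetical plain-dict node records instead of torch.fx nodes:

    from typing import Dict, List

    def peak_backward_memory(schedule: List[dict], size: Dict[str, int]) -> int:
        # Each entry of `schedule` is {'name': str, 'users': int, 'inputs': [str]}.
        # deps[name] counts not-yet-consumed users; float('-inf') marks "freed".
        deps: Dict[str, float] = {}
        live = peak = 0
        for node in schedule:
            deps[node['name']] = node['users']
            live += size[node['name']]          # activation becomes resident
            peak = max(peak, live)
            for inp in node['inputs']:
                if inp in deps:
                    deps[inp] -= 1
                    if deps[inp] <= 0:          # last consumer seen:
                        deps[inp] = float('-inf')
                        live -= size[inp]       # activation can be released
        return peak

    # Example: two 4-byte activations overlap before `a` is freed -> peak is 8
    # peak_backward_memory([{'name': 'a', 'users': 1, 'inputs': []},
    #                       {'name': 'b', 'users': 0, 'inputs': ['a']}],
    #                      {'a': 4, 'b': 4})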
colossalai/fx/profiler/memory.py
@@ -1,9 +1,10 @@
 import torch
+from torch.fx import Node
 from typing import Union, Dict, List, Tuple
 from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
 from . import META_COMPATIBILITY
 
-__all__ = ['activation_size', 'parameter_size']
+__all__ = ['activation_size', 'parameter_size', 'is_inplace']
 
 if META_COMPATIBILITY:
     aten = torch.ops.aten
@@ -21,6 +22,7 @@ if META_COMPATIBILITY:
         aten.bernoulli_.float,
 
         # inplace reshaping
+        aten.copy_.default,
         aten.detach.default,
         aten.t.default,
         aten.transpose.int,
@@ -28,7 +30,17 @@ if META_COMPATIBILITY:
         aten._unsafe_view.default,
     ]
 
-    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS']
+    NORMALIZATION_ATEN = [
+        aten.native_batch_norm.default,
+        aten.native_layer_norm.default,
+        # aten.max_pool2d_with_indices.default,
+    ]
+
+    CLONE_ATEN = [
+        aten.clone.default,
+    ]
+
+    __all__ += ['INPLACE_ATEN', 'WEIRD_OPS', 'NORMALIZATION_ATEN', 'CLONE_ATEN']
 
 else:
     # TODO fill out the inplace ops
@@ -106,3 +118,23 @@ def parameter_size(mod: torch.nn.Module) -> int:
     for param in mod.parameters():
         param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
     return param_size
+
+
+def is_inplace(n: Node):
+    """Get the inplace argument from torch.fx.Node
+
+    Args:
+        node (Node): torch.fx.Node
+
+    Returns:
+        bool: indicates whether this op is inplace
+    """
+    inplace = False
+    if n.op == "call_function":
+        inplace = n.kwargs.get("inplace", False)
+        if META_COMPATIBILITY and n.target in INPLACE_ATEN:
+            inplace = True
+    elif n.op == "call_module":
+        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+
+    return inplace
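The new is_inplace helper (re-exported by the __init__.py hunk above) can be exercised on any traced module. A minimal sketch, with an illustrative model that is not part of the commit:

    import torch
    from torch.fx import symbolic_trace
    from colossalai.fx.profiler import is_inplace

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(inplace=True))
    gm = symbolic_trace(model)
    for node in gm.graph.nodes:
        if node.op == 'call_module':
            # the ReLU submodule reports True via its `inplace` attribute;
            # the Linear submodule has no such attribute and reports False
            print(node.target, is_inplace(node))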
colossalai/fx/profiler/opcount.py
@@ -222,6 +222,7 @@ flop_mapping = {
     aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
     aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
     aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
+    aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
 }
 
 elementwise_flop_aten = [
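The added entry charges aten.embedding_dense_backward one FLOP per element of its output gradient, matching the neighboring backward entries. For context, here is a sketch of a fvcore-style elementwise_flop_counter, written under the assumption that opcount.py follows that convention (the file's actual definition is not shown in this diff):

    from functools import reduce
    import operator

    def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0):
        """Charge input_scale FLOPs per element of the first input,
        plus output_scale FLOPs per element of the first output."""
        def counter(inputs, outputs):
            numel = lambda t: reduce(operator.mul, t.shape, 1)
            return input_scale * numel(inputs[0]) + output_scale * numel(outputs[0])
        return counter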
colossalai/fx/profiler/profiler.py
@@ -1,12 +1,10 @@
-from dataclasses import dataclass
-from enum import auto
 from typing import Callable, Any, Dict, Tuple
 import torch
 from torch.fx import Graph, Node
 from torch.fx.node import Argument, Target
 from torch.utils._pytree import tree_map
 from .dataflow import GraphInfo, autograd_graph_analysis, Phase
-from .memory import WEIRD_OPS, activation_size
+from .memory import WEIRD_OPS
 from .tensor import MetaTensor
 from .opcount import flop_mapping
@@ -23,7 +21,7 @@ def is_autogradable(x):
     return isinstance(x, torch.Tensor) and x.is_floating_point()
 
 
-def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
+def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
     """
     Profile a Callable function with args and kwargs.
@@ -42,7 +40,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     # `flop_count` serves as a global dictionary to store results.
     flop_count = {
         Phase.FORWARD: 0,
-        Phase.LOSS: 0,
         Phase.BACKWARD: 0,
     }
@@ -71,6 +68,10 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
         kwargs_node = tree_map(get_node, kwargs)
         node = subgraph.create_node('call_function', func, args_node, kwargs_node)
 
+        # do not allocate on `cpu`
+        if 'device' in kwargs:
+            kwargs['device'] = 'meta'
+
         def unwrap(x):
             # if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
             if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
@@ -101,13 +102,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     if target not in WEIRD_OPS:
 
         def wrap(x):
-            return FlopTensor(x.detach().requires_grad_(
-                True)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
+            return FlopTensor(
+                x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
     else:
 
         def wrap(x):
-            return FlopTensor(x.detach().requires_grad_(
-                False)) if is_autogradable(x) and not inplace and not hasattr(x, '_tensor') else x
+            return FlopTensor(
+                x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
 
     # Basically, we need to detach the args and kwargs from the outer graph.
     args = tree_map(wrap, args)
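The `wrap` refactor above keeps behavior identical while dropping the now-unused inplace flag: each floating-point tensor in args/kwargs is detached from the caller's autograd graph and gradients are re-enabled locally, so the profiled op records its own backward. The underlying tree_map-detach pattern, demonstrated on plain tensors (names are illustrative):

    import torch
    from torch.utils._pytree import tree_map

    def wrap(x):
        # detach from the outer graph, then track gradients locally
        return x.detach().requires_grad_(True) \
            if isinstance(x, torch.Tensor) and x.is_floating_point() else x

    args = (torch.randn(2, 2), {'scale': 2.0, 'bias': torch.randn(2)})
    args = tree_map(wrap, args)    # tensors are detached, non-tensors pass through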
@@ -125,7 +126,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     tree_map(set_placeholder, kwargs)
 
     def pack(x):
-        if isinstance(x, FlopTensor):
+        if isinstance(x, FlopTensor) and not isinstance(x, torch.nn.Parameter):
             x._node.meta['saved'] = True
         return x
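`pack` here appears to be the pack hook of torch.autograd.graph.saved_tensors_hooks, which intercepts every tensor autograd stashes for backward; after this change, model parameters are no longer flagged as saved activations. A minimal standalone demo of that hook mechanism (independent of FlopTensor):

    import torch

    saved_shapes = []

    def pack(x):
        saved_shapes.append(tuple(x.shape))    # record what autograd stashes
        return x

    def unpack(x):
        return x

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        a = torch.randn(3, 3, requires_grad=True)
        out = (a * a).sum()
    out.backward()
    print(saved_shapes)    # shapes of the tensors saved for the backward pass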
@@ -143,13 +144,15 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...]:
     else:
         out = target(*args, **kwargs)
 
-    # If the output is not a floating point `torch.Tensor` or it does not
-    # requires grad, then we should not run backward for this node.
-    if is_autogradable(out) and out.requires_grad:
-        phase = Phase.LOSS
-        loss = out.sum()
-        phase = Phase.BACKWARD
-        loss.backward()
+    # If the output is not a floating point `torch.Tensor` or it does not
+    # require grad, then we should not run backward for this node.
+    if is_autogradable(out) and out.requires_grad:
+        phase = Phase.BACKWARD
+        if isinstance(out, FlopTensor):
+            out._node.meta['save'] = False
+        grad = torch.empty_like(out._tensor, device='meta') if isinstance(out, FlopTensor) else torch.empty_like(
+            out, device='meta')
+        torch.autograd.backward(out, FlopTensor(grad))
 
     graph_info = autograd_graph_analysis(subgraph)
     graph_info.fwd_flop, graph_info.bwd_flop = flop_count[Phase.FORWARD], flop_count[Phase.BACKWARD]
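Instead of fabricating a scalar loss with out.sum(), the profiler now seeds torch.autograd.backward with an explicit gradient of the output's shape (allocated on the meta device), which is why the LOSS phase disappears from the enum and the flop_count dictionary. The pattern in isolation, on plain CPU tensors so the result is observable:

    import torch

    x = torch.randn(4, 4, requires_grad=True)
    out = x * 2

    # old approach: loss = out.sum(); loss.backward()
    # new approach: seed backward directly, so no reduction node is recorded
    torch.autograd.backward(out, torch.ones_like(out))
    print(x.grad)    # all 2.0, identical to the sum().backward() result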
@@ -172,7 +175,7 @@ def profile_function(target: 'Target') -> Callable:
     Examples:
         >>> input = torch.rand(100, 100, 100, 100, device='meta')
         >>> func = torch.nn.functional.relu
-        >>> output, meta_info = profile_function(func)(input, inplace=False)
+        >>> output, meta_info = profile_function(func)(input)
     """
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@@ -183,7 +186,7 @@ def profile_function(target: 'Target') -> Callable:
             args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
             kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
             out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
+            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
         out, meta = _profile(func, *args, **kwargs)
         return out, meta
@@ -201,7 +204,7 @@ def profile_method(target: 'Target') -> Callable:
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # execute the method and return the result
         assert isinstance(target, str), f'{target} instance is not str.'
-        out, meta = _profile(target, *args, inplace=False, **kwargs)
+        out, meta = _profile(target, *args, **kwargs)
         return out, meta
 
     return f
@@ -230,8 +233,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
             args = tree_map(lambda x: x.to('meta'), args)
             kwargs = tree_map(lambda x: x.to('meta'), kwargs)
             out = func(*args, **kwargs)
-            return out, GraphInfo(out.numel(), out.numel(), activation_size((args, kwargs)), 0, activation_size(out), 0)
-        out, meta = _profile(func, *args, inplace=getattr(module, 'inplace', False), **kwargs)
+            return out, GraphInfo(out.numel(), out.numel(), 0, 0, 0, 0)
+        out, meta = _profile(func, *args, **kwargs)
         return out, meta
 
     f.__name__ = module.__class__.__name__
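With the inplace plumbing removed, the wrappers are called with just the inputs. A hedged end-to-end sketch of the intended usage on meta tensors (module and shapes are illustrative, mirroring the docstring example fixed above):

    import torch
    from colossalai.fx.profiler import profile_module

    mod = torch.nn.Linear(64, 64, device='meta')
    inp = torch.rand(8, 64, device='meta')
    out, meta_info = profile_module(mod)(inp)
    print(meta_info.fwd_flop, meta_info.bwd_flop)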