Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)
[fx/profiler] tuned the calculation of memory estimation (#1619)
* [fx] tuned the meta info and rotor solver.
* [fx] remove import.
* [fx] remove import.
* [fx] remove import.
* [fx] tune the meta calculations.
* [fx] polish comments.
* [fx] remove assertions.
* [fx] modify test cases.
* [fx] modify test cases.
* [fx] optimize import.
* [fx
@@ -3,9 +3,6 @@ from enum import Enum
 from typing import Dict
 from torch.fx import Graph, Node
 from .memory import activation_size, is_inplace
-from . import META_COMPATIBILITY
-if META_COMPATIBILITY:
-    from .memory import NORMALIZATION_ATEN, CLONE_ATEN
 
 
 class Phase(Enum):
@@ -23,29 +20,32 @@ class GraphInfo:
     ============================================================================
                                 -------------------------------
                                 |            Node             |
-    [fwd_in] are       ---> | [fwd_in]       [bwd_out]    |    <----- [bwd_out] is marks the memory for `grad_out`
+    [fwd_in] are       ---> | [fwd_in]       [bwd_out]    |    <----- [bwd_out] marks the memory for `grad_out`.
     placeholders saved for  |     | \__________     |     |
     backward.               |     |            \    |     |
                             | [fwd_tmp] ------> [bwd_tmp] |    <-----
                             |     |  \_________     |     |           [bwd_tmp] marks the peak memory
                             |     /  \           \  |     |           in backward pass.
     [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |    <-----
-    in [fwd_tmp] because    |         |   |  \_____  |    |
-    it is not saved for     |         |   |       \  |    |
-    backward.               -------------------------------
+    in [fwd_tmp] because    |         |  \_____      |    |
+    it is not saved for     |         |        \     |    |
+    backward.               | [fwd_out]     \    |        |    <----- [fwd_out] is [fwd_in] for the next node.
+                            -------------------------------
     ============================================================================
     Attributes:
         fwd_flop (int): The forward FLOPs of a certain node
         bwd_flop (int): The backward FLOPs of a certain node.
-        fwd_mem_in (int): See the above illustration.
+        save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
         fwd_mem_tmp (int): See the above illustration.
         fwd_mem_out (int): See the above illustration.
         bwd_mem_tmp (int): See the above illustration.
         bwd_mem_out (int): See the above illustration.
     """
     fwd_flop: int = 0
     bwd_flop: int = 0
-    fwd_mem_in: int = 0
+    save_fwd_in: bool = False
     fwd_mem_tmp: int = 0
     fwd_mem_out: int = 0
     bwd_mem_tmp: int = 0
     bwd_mem_out: int = 0
 
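The illustration above splits a node's footprint into forward and backward buckets. As a rough, self-contained sketch (not code from this commit) of how such a record could feed a per-node estimate, assuming checkpointing frees `fwd_mem_tmp` but must keep the node's output:

from dataclasses import dataclass


@dataclass
class NodeMemRecord:
    """Hypothetical stand-in for the GraphInfo fields documented above."""
    fwd_flop: int = 0
    bwd_flop: int = 0
    save_fwd_in: bool = False
    fwd_mem_tmp: int = 0
    fwd_mem_out: int = 0
    bwd_mem_tmp: int = 0
    bwd_mem_out: int = 0


def forward_footprint(rec: NodeMemRecord, checkpoint: bool) -> int:
    # Sketched model: checkpointing drops the forward temporaries,
    # but the node's output must survive either way.
    return rec.fwd_mem_out if checkpoint else rec.fwd_mem_out + rec.fwd_mem_tmp


rec = NodeMemRecord(fwd_mem_tmp=4 * 1024, fwd_mem_out=2 * 1024)
print(forward_footprint(rec, checkpoint=False))  # 6144
print(forward_footprint(rec, checkpoint=True))   # 2048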
@@ -56,7 +56,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
 
 
 def is_saved(n: Node):
-    return n.meta.get('saved', False)
+    return len(n.meta['saved_tensor'])
 
 
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
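The new `is_saved` returns a length rather than a flag, so callers rely on truthiness: an empty `saved_tensor` list means nothing is kept for backward. A minimal illustration with a stand-in node object (the `meta` layout here is only inferred from this diff):

from types import SimpleNamespace


def is_saved(n) -> int:
    # Same idea as the new line in the diff: a node counts as "saved"
    # if it records at least one tensor kept for the backward pass.
    return len(n.meta['saved_tensor'])


kept = SimpleNamespace(meta={'saved_tensor': ['activation_buffer']})
dropped = SimpleNamespace(meta={'saved_tensor': []})

print(bool(is_saved(kept)))     # True
print(bool(is_saved(dropped)))  # False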
@@ -87,10 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     def _peak_memory(deps: Dict[Node, int]):
         peak_mem = 0
         for k, v in deps.items():
-            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
-                peak_mem += activation_size(k.meta['out'])
-            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
-                peak_mem -= activation_size(k.meta['out'])
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
+                peak_mem += activation_size(k.meta['saved_tensor'])
+            if v <= float('-inf') and is_phase(k, Phase.FORWARD):
+                peak_mem -= activation_size(k.meta['saved_tensor'])
         return peak_mem
 
     # deps is used to track all the memory dependencies of the graph.
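`_peak_memory` reads the dependency counters: a positive count marks a live backward buffer, while a counter forced to `-inf` marks a forward buffer that has already been released. A toy recreation of that bookkeeping on plain dictionaries, ignoring the in-place checks (node names and sizes are invented):

import math

FORWARD, BACKWARD = 'forward', 'backward'

# Invented nodes: name -> (phase, bytes recorded in its saved tensors).
nodes = {
    'conv_fwd': (FORWARD, 4096),
    'conv_bwd': (BACKWARD, 4096),
    'relu_bwd': (BACKWARD, 1024),
}


def peak_memory(deps):
    # Add live backward buffers, subtract forward buffers whose counter
    # was forced to -inf (i.e. already released), as in the hunk above.
    peak = 0
    for name, count in deps.items():
        phase, size = nodes[name]
        if count > 0 and phase == BACKWARD:
            peak += size
        if count <= -math.inf and phase == FORWARD:
            peak -= size
    return peak


print(peak_memory({'conv_bwd': 1, 'relu_bwd': 2, 'conv_fwd': -math.inf}))  # 1024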
@@ -99,25 +99,25 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
 
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
-            # A forward tensor who is marked `save` but is not
-            # an input to `loss` should be saved during forward.
-            # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
-            # Any `fwd_mem_in` should be kept in memory even this function
-            # is checkpointed.
-            # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
-            # the node, `fwd_mem_tmp` can be freed.
-            if is_phase(n, Phase.PLACEHOLDER):
-                graph_info.fwd_mem_in += activation_size(n.meta['out'])
-            if is_phase(n, Phase.FORWARD):
-                graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
+        deps[n] = len(n.users)
+        # A forward tensor that is marked `save` but is also
+        # an input to `Phase.FORWARD` should be saved during forward.
+        # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
+        # Any `fwd_mem_in` should be kept in memory even if this function
+        # is checkpointed.
+        # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
+        # the node, `fwd_mem_tmp` can be freed.
+        if is_phase(n, Phase.PLACEHOLDER):
+            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
+        if is_phase(n, Phase.FORWARD):
+            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
             else:
                 # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
-                graph_info.bwd_mem_out += activation_size(n.meta['out'])
+                graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
         for input_n in n.all_input_nodes:
             if input_n in deps:
                 deps[input_n] -= 1
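The loop above is essentially a reference-counting liveness pass: each node starts with one credit per user, and visiting a consumer decrements its producers. A stripped-down version of that traversal on a hand-written DAG, with no torch.fx involved (node names and sizes are invented, and the peak update is a coarse stand-in for the `bwd_mem_tmp` bookkeeping):

# A hand-written DAG standing in for an FX graph: node -> its input nodes.
inputs = {
    'x': [],
    'matmul': ['x'],
    'loss': ['matmul'],
    'matmul_bwd': ['loss'],
}
# Derive the user lists, analogous to node.users in torch.fx.
users = {n: [m for m, ins in inputs.items() if n in ins] for n in inputs}
sizes = {'x': 512, 'matmul': 2048, 'loss': 4, 'matmul_bwd': 512}

deps = {}
peak_tmp = 0
for n in inputs:                        # assume topological order, as in an FX graph
    deps[n] = len(users[n])             # credits: consumers that still need this node
    live = sum(sizes[k] for k, v in deps.items() if v > 0)
    peak_tmp = max(peak_tmp, live)      # coarse stand-in for the bwd_mem_tmp update
    for input_n in inputs[n]:
        if input_n in deps:
            deps[input_n] -= 1          # one consumer of input_n has now run

print(peak_tmp)                         # 2560 with the sizes above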