Mirror of https://github.com/hpcaitech/ColossalAI.git
[fx/tuning] tune performance on rotor with meta info. (#1599)
```diff
@@ -1,16 +1,17 @@
 from dataclasses import dataclass
 from enum import Enum
 from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
-from .memory import activation_size
+from .memory import activation_size, is_inplace
+from . import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from .memory import NORMALIZATION_ATEN, CLONE_ATEN
 
 
 class Phase(Enum):
     FORWARD = 0
-    LOSS = 1
-    BACKWARD = 2
-    PLACEHOLDER = 3
+    BACKWARD = 1
+    PLACEHOLDER = 2
 
 
 @dataclass
```
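The hunks below lean on two helpers, `is_phase` and `is_saved`, that this diff does not show. A minimal sketch of how such predicates could look, assuming the profiler's tracer stashes a phase tag and a saved flag in `node.meta` (the key names `'phase'` and `'saved'` are illustrative, not necessarily the repository's actual ones):

```python
from enum import Enum
from torch.fx import Node

class Phase(Enum):          # mirrors the renumbered enum above
    FORWARD = 0
    BACKWARD = 1
    PLACEHOLDER = 2

def is_phase(n: Node, phase: Phase) -> bool:
    # Assumption: each node carries the autograd phase it belongs to.
    return n.meta.get('phase') == phase

def is_saved(n: Node) -> bool:
    # Assumption: a truthy flag marks outputs kept alive for backward.
    return bool(n.meta.get('saved', False))
```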
```diff
@@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     def _peak_memory(deps: Dict[Node, int]):
         peak_mem = 0
         for k, v in deps.items():
-            if v > 0:
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
                 peak_mem += activation_size(k.meta['out'])
+            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
+                peak_mem -= activation_size(k.meta['out'])
         return peak_mem
 
     # deps is used to track all the memory dependencies of the graph.
```
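The liveness convention behind `_peak_memory`: `deps[k]` counts the backward users of `k` that have not yet executed, so `v > 0` means the activation is still alive, while `v <= float('-inf')` (set in the last hunk once the count reaches zero) means it has been freed and any saved copy can be discounted from the peak. A toy illustration of that counter lifecycle, using plain strings instead of fx nodes:

```python
deps = {'conv_out': 2, 'relu_out': 1}   # pending backward users per tensor

def consume(name: str) -> None:
    # One backward user of `name` has run; at zero the tensor is freed
    # and the counter is pinned to -inf so it is never counted live again.
    deps[name] -= 1
    if deps[name] <= 0:
        deps[name] = float('-inf')

consume('relu_out')
assert deps['relu_out'] == float('-inf')   # freed: hits the subtract branch
assert deps['conv_out'] == 2               # still live: hits the add branch
```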
```diff
@@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
 
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
+        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
             # A forward tensor who is marked `save` but is not
             # an input to `loss` should be saved during forward.
             # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
```
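One subtlety when reading the new predicate: Python parses `A and B or C` as `(A and B) or C`, so a node is counted when it is saved and is not a normalization op, or when any of its users is a clone op. A fully parenthesized restatement, treating `NORMALIZATION_ATEN` and `CLONE_ATEN` as sets of aten targets (which is what the guarded import in the first hunk suggests) and reusing the `is_saved` sketch above:

```python
from typing import Set
from torch.fx import Node

def counts_as_saved(n: Node, normalization_aten: Set, clone_aten: Set) -> bool:
    # Same parse as the one-liner above: (saved and not normalization) OR clone-user.
    return (is_saved(n) and n.target not in normalization_aten) \
        or any(user.target in clone_aten for user in n.users)
```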
```diff
@@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
             graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 # liveness analysis is only used in backward
                 deps[n] = len(n.users)
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
-                for input_n in n.all_input_nodes:
-                    if input_n in deps:
-                        deps[input_n] -= 1
             else:
                 # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
                 graph_info.bwd_mem_out += activation_size(n.meta['out'])
+            for input_n in n.all_input_nodes:
+                if input_n in deps:
+                    deps[input_n] -= 1
+                    if deps[input_n] <= 0:
+                        deps[input_n] = float('-inf')
     return graph_info
```
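For orientation, the `GraphInfo` dataclass that `autograd_graph_analysis` fills in sits outside the hunks shown. From the fields this commit reads and writes, a plausible minimal shape is sketched below; it is inferred from the diff alone, so the real dataclass may carry more fields:

```python
from dataclasses import dataclass

@dataclass
class GraphInfo:
    fwd_mem_in: int = 0    # activation memory of saved placeholders (inputs)
    fwd_mem_tmp: int = 0   # intermediate activations saved during forward
    bwd_mem_tmp: int = 0   # peak transient memory while backward runs
    bwd_mem_out: int = 0   # `grad_out` memory that outlives the backward pass
```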