[fx/tuning] tune performance on rotor with meta info. (#1599)

Super Daniel
2022-09-15 14:46:36 +08:00
committed by GitHub
parent a7cda6f57d
commit cd5cf2bcc9
7 changed files with 96 additions and 107 deletions

@@ -1,16 +1,17 @@
 from dataclasses import dataclass
 from enum import Enum
 from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
-from .memory import activation_size
+from .memory import activation_size, is_inplace
+from . import META_COMPATIBILITY
+if META_COMPATIBILITY:
+    from .memory import NORMALIZATION_ATEN, CLONE_ATEN
 class Phase(Enum):
     FORWARD = 0
-    LOSS = 1
-    BACKWARD = 2
-    PLACEHOLDER = 3
+    BACKWARD = 1
+    PLACEHOLDER = 2
 @dataclass
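
The new `META_COMPATIBILITY` guard only pulls in the aten-op sets (`NORMALIZATION_ATEN`, `CLONE_ATEN`) when meta-tensor profiling is available. The analysis below also leans on small predicates over `Node.meta` such as `is_phase` and `is_saved`; a minimal sketch of what they check, assuming the tracer records the phase and save flag under `meta['phase']` / `meta['saved']` (the key names are assumptions for illustration, not taken from this diff):

```python
from torch.fx import Node

# `Phase` is the enum defined just above.
def is_phase(n: Node, phase: "Phase") -> bool:
    # assumed: the profiler's tracer tags every fx node with the phase
    # (FORWARD / BACKWARD / PLACEHOLDER) that produced it
    return n.meta.get('phase') == phase

def is_saved(n: Node) -> bool:
    # assumed: activations kept for the backward pass are flagged in meta
    return bool(n.meta.get('saved', False))
```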
@@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     def _peak_memory(deps: Dict[Node, int]):
         peak_mem = 0
         for k, v in deps.items():
-            if v > 0:
+            if v > 0 and is_phase(k, Phase.BACKWARD) and not any(map(is_inplace, k.users)):
                 peak_mem += activation_size(k.meta['out'])
+            if v <= float('-inf') and is_saved(k) and (k.target not in NORMALIZATION_ATEN):
+                peak_mem -= activation_size(k.meta['out'])
         return peak_mem
     # deps is used to track all the memory dependencies of the graph.
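
With this change, `deps` entries behave like reference counts with a tombstone: a positive value means the activation still has pending consumers and is counted while it is a backward-phase, non-inplace output; once the count is forced to `float('-inf')` (see the last hunk) the saved activation is treated as freed and subtracted. A standalone toy of that counter/tombstone bookkeeping, with made-up node names and sizes in MB:

```python
import math

def peak_with_tombstones(users_left: dict, size_mb: dict) -> float:
    """Toy version of the `deps` bookkeeping: positive count = still live,
    -inf = tombstone for a buffer that has already been released."""
    peak = 0.0
    for name, remaining in users_left.items():
        if remaining > 0:
            peak += size_mb[name]      # still needed by some consumer
        elif remaining == -math.inf:
            peak -= size_mb[name]      # saved earlier, freed by now
    return peak

# two live activations (4 MB, 2 MB) and one freed 1 MB buffer -> 5.0
print(peak_with_tombstones({'a': 2, 'b': 1, 'c': -math.inf},
                           {'a': 4, 'b': 2, 'c': 1}))
```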
@@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
+        if is_saved(n) and (n.target not in NORMALIZATION_ATEN) or any(map(lambda x: x.target in CLONE_ATEN, n.users)):
             # A forward tensor that is marked `save` but is not
             # an input to `loss` should be saved during forward.
             # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
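
The replacement condition inspects both the producing op (outputs of `NORMALIZATION_ATEN` ops are no longer counted as saved) and the consumers (the tensor is kept if any user is a clone op). A toy illustration of the `any(map(...))` users check on a traced module; `CLONE_ATEN_DEMO` is a hypothetical stand-in for the real `CLONE_ATEN` set from `.memory`:

```python
import torch
from torch.fx import symbolic_trace

CLONE_ATEN_DEMO = {'clone', torch.clone}    # hypothetical stand-in for CLONE_ATEN

def any_user_clones(node) -> bool:
    # keep a saved activation if one of its consumers is a clone op
    return any(u.target in CLONE_ATEN_DEMO for u in node.users)

class M(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return y.clone()

gm = symbolic_trace(M())
for n in gm.graph.nodes:
    print(n.name, any_user_clones(n))   # the `add` node reports True
```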
@@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
                 graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
-                # liveness analysis is only used in backward
                 deps[n] = len(n.users)
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
-                for input_n in n.all_input_nodes:
-                    if input_n in deps:
-                        deps[input_n] -= 1
             else:
                 # TODO: some of the bwd_mem_out might be model parameters.
                 # basically a backward node without user is a `grad_out` node
                 graph_info.bwd_mem_out += activation_size(n.meta['out'])
+            for input_n in n.all_input_nodes:
+                if input_n in deps:
+                    deps[input_n] -= 1
+                    if deps[input_n] <= 0:
+                        deps[input_n] = float('-inf')
     return graph_info
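
Downstream, the rotor solver consumes the four memory fields this pass fills in. A usage sketch, assuming the module path `colossalai.fx.profiler.dataflow` and a graph that has already been annotated with the phase/save/`meta['out']` information the pass expects:

```python
from colossalai.fx.profiler.dataflow import autograd_graph_analysis  # path assumed

def summarize(annotated_graph):
    info = autograd_graph_analysis(annotated_graph)
    # fields visible in this diff; sizes come from activation_size (bytes)
    print(f"fwd_mem_in : {info.fwd_mem_in}")
    print(f"fwd_mem_tmp: {info.fwd_mem_tmp}")
    print(f"bwd_mem_tmp: {info.bwd_mem_tmp}")
    print(f"bwd_mem_out: {info.bwd_mem_out}")
```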