[hotfix/rotor] fix variable names (#1597)

* [fx] add some comments and docstrings.

* [fx] add dataflow analysis for an autograd graph.

* add interpretation for graph analysis.

* [fx] before doing save_tensor_hooks.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] a very accurate version on GPT-2.

* [fx] refactor code.

* [fx] remove redundant inplace=True.

* [fx] refactor code.

* [fx] refactor code.

* [fx] refactor code.

* [fx] dive into backward memory.

* [fx] fix variable names in ckpt_solvers and unskip tests.

* [fx] commit my changes.

* [fx] restore skips.

* [fx] restore skips.

* [fx] change stage into phase.

* [fx] change stage into phase.

* [fx] change stage into phase.
Author: Super Daniel
Date: 2022-09-14 14:27:04 +08:00
Committed by: GitHub
Parent: faa23b9d9a
Commit: c8e9b2ad78
7 changed files with 85 additions and 83 deletions

@@ -1,11 +1,12 @@
 from dataclasses import dataclass
 from enum import Enum
+from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
 from .memory import activation_size
-class Stage(Enum):
+class Phase(Enum):
     FORWARD = 0
     LOSS = 1
     BACKWARD = 2
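
For orientation outside the hunk: predicates later in the file match nodes against this enum through `n.meta['phase']`. Below is a minimal, runnable sketch of that contract; the `PLACEHOLDER` member is assumed from its use further down (its definition sits outside this hunk), and `tag_phase` is a hypothetical helper, not part of this PR.

    from enum import Enum
    from types import SimpleNamespace

    class Phase(Enum):
        FORWARD = 0
        LOSS = 1
        BACKWARD = 2
        PLACEHOLDER = 3  # assumed: referenced by the analysis below, defined outside this hunk

    def tag_phase(n, phase: Phase) -> None:
        # Hypothetical helper: record the phase in the node's metadata,
        # where `is_phase` will later look it up.
        n.meta['phase'] = phase

    node = SimpleNamespace(meta={})  # stand-in for a torch.fx.Node
    tag_phase(node, Phase.FORWARD)
    assert node.meta['phase'] is Phase.FORWARD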
@@ -48,24 +49,9 @@ class GraphInfo:
     bwd_mem_out: int = 0
-def is_forward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.FORWARD
-def is_loss(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.LOSS
-def is_placeholder(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.PLACEHOLDER
-def is_backward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.BACKWARD
+def is_phase(n: Node, phase: Phase) -> bool:
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == phase
 def is_saved(n: Node):
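
The rename also collapses four near-identical predicates into one parameterized check; call sites that used to pass `is_loss` straight to `map` now specialize `is_phase` with `functools.partial`. A self-contained sketch of the pattern (the stand-in node class is illustrative; the real code operates on `torch.fx.Node`):

    from enum import Enum
    from functools import partial

    class Phase(Enum):
        FORWARD = 0
        LOSS = 1
        BACKWARD = 2

    class _Node:
        # Illustrative stand-in for torch.fx.Node: only the `meta` dict matters here.
        def __init__(self, phase: Phase):
            self.meta = {'phase': phase}

    def is_phase(n, phase: Phase) -> bool:
        assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
        return n.meta['phase'] == phase

    users = [_Node(Phase.FORWARD), _Node(Phase.LOSS)]
    # The old `any(map(is_loss, users))` becomes:
    print(any(map(partial(is_phase, phase=Phase.LOSS), users)))  # True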
@@ -74,7 +60,7 @@ def is_saved(n: Node):
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
-    Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+    Basically the input graph should have all nodes marked for keyword `phase`.
     Nodes should have attribute `out` indicating the output of each node.
     ============================================================================
     Placeholder ----> p    o <---- We need to keep track of grad out
@@ -91,18 +77,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
                       l
     =============================================================================
     Args:
-        graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+        graph (Graph): The autograd graph with nodes marked for keyword `phase`.
     Returns:
         graph_info (GraphInfo): Meta information for the dataflow.
     """

     def _peak_memory(deps: Dict[Node, int]):
-        bwd_tmp = 0
+        peak_mem = 0
         for k, v in deps.items():
             if v > 0:
-                bwd_tmp += activation_size(k.meta['out'])
-        return bwd_tmp
+                peak_mem += activation_size(k.meta['out'])
+        return peak_mem

     # deps is used to track all the memory dependencies of the graph.
     deps = {}
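
The rename from `bwd_tmp` to `peak_mem` matches what the helper computes: it sums the sizes of all activations whose remaining user count in `deps` is still positive, i.e. tensors still live at the current point of the backward walk. A standalone sketch of the same computation, with a plain size dict standing in for `activation_size(k.meta['out'])`:

    from typing import Dict

    def peak_memory(deps: Dict[str, int], sizes: Dict[str, int]) -> int:
        # A node with a positive remaining user count is still referenced,
        # so its activation cannot be freed yet and contributes to the total.
        peak_mem = 0
        for k, v in deps.items():
            if v > 0:
                peak_mem += sizes[k]
        return peak_mem

    # 'c' has been fully consumed (count 0) and no longer occupies memory:
    print(peak_memory({'a': 2, 'b': 1, 'c': 0}, {'a': 1024, 'b': 512, 'c': 2048}))  # 1536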
@@ -110,19 +96,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(is_loss, n.users)):
+        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
             # A forward tensor who is marked `save` but is not
             # an input to `loss` should be saved during forward.
-            # If the tensor is a placeholder, then it belongs to `fwd_in`.
-            # Any `fwd_in` should be kept in memory even this function
+            # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
+            # Any `fwd_mem_in` should be kept in memory even this function
             # is checkpointed.
-            # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
-            # the node, `fwd_tmp` can be freed.
-            if is_placeholder(n):
+            # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
+            # the node, `fwd_mem_tmp` can be freed.
+            if is_phase(n, Phase.PLACEHOLDER):
                 graph_info.fwd_mem_in += activation_size(n.meta['out'])
-            if is_forward(n):
+            if is_phase(n, Phase.FORWARD):
                 graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
-        elif is_backward(n):
+        elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 # liveness analysis is only used in backward
                 deps[n] = len(n.users)
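
Putting the hunks together, a hedged sketch of how the analysis might be driven end to end. The tagging pass is not part of this diff, so the loop below fakes `phase`, `out`, and a `saved` flag by hand; the `saved` key is an assumption (the body of `is_saved` is elided above), and the import path of the edited module is omitted.

    import torch
    from torch.fx import symbolic_trace

    # from <edited module> import Phase, autograd_graph_analysis  # path assumed

    gm = symbolic_trace(torch.nn.Linear(4, 4))
    for n in gm.graph.nodes:
        # Assumed upstream contract: a real pass records the true phase and
        # the node's concrete output; both are faked here for illustration.
        n.meta['phase'] = Phase.PLACEHOLDER if n.op == 'placeholder' else Phase.FORWARD
        n.meta['saved'] = True           # assumed flag read by `is_saved`
        n.meta['out'] = torch.empty(4)   # stand-in for the real output tensor

    info = autograd_graph_analysis(gm.graph)
    print(info.fwd_mem_in, info.fwd_mem_tmp, info.bwd_mem_out)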