Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[hotfix/rotor] fix variable names (#1597)
* [fx] add some comment and docstrings.
* [fx] add dataflow analysis for an autograd graph.
* add interpretation for graph analysis.
* [fx] before doing save_tensor_hooks.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] provide an accurate estimation of memory except for GPT-2.
* [fx] a very accurate version on GPT-2.
* [fx] refactor code.
* [fx] remove redundant inplace=True.
* [fx] refactor code.
* [fx] refactor code.
* [fx] refactor code.
* [fx] dive into backward memory.
* [fx] fix variable names in ckpt_solvers and unskip tests.
* [fx] commit my changes.
* [fx] restore skips.
* [fx] restore skips.
* [fx] change stage into phase.
* [fx] change stage into phase.
* [fx] change stage into phase.
@@ -1,11 +1,12 @@
 from dataclasses import dataclass
 from enum import Enum
+from functools import partial
 from typing import Dict
 from torch.fx import Graph, Node
 from .memory import activation_size
 
 
-class Stage(Enum):
+class Phase(Enum):
     FORWARD = 0
     LOSS = 1
     BACKWARD = 2
@@ -48,24 +49,9 @@ class GraphInfo:
     bwd_mem_out: int = 0
 
 
-def is_forward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.FORWARD
-
-
-def is_loss(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.LOSS
-
-
-def is_placeholder(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.PLACEHOLDER
-
-
-def is_backward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.BACKWARD
+def is_phase(n: Node, phase: Phase) -> bool:
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == phase
 
 
 def is_saved(n: Node):
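As context for this hunk, the four per-stage predicates collapse into a single parametric `is_phase` check. Below is a minimal sketch of how that helper behaves on a hand-annotated `torch.fx` graph; the toy module, the manual `meta['phase']` tagging, and the `PLACEHOLDER` member are illustrative assumptions for the example, not part of this diff.

from enum import Enum
from functools import partial

import torch
from torch.fx import symbolic_trace


class Phase(Enum):              # mirrors the enum introduced by this commit
    FORWARD = 0
    LOSS = 1
    BACKWARD = 2
    PLACEHOLDER = 3             # assumed member; the analysis checks placeholders against it


def is_phase(n, phase):         # same shape as the new helper in the diff
    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
    return n.meta['phase'] == phase


class Toy(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1


gm = symbolic_trace(Toy())
for node in gm.graph.nodes:
    # Hand-tag every node: graph inputs get PLACEHOLDER, everything else FORWARD (illustration only).
    node.meta['phase'] = Phase.PLACEHOLDER if node.op == 'placeholder' else Phase.FORWARD

fwd_nodes = [n for n in gm.graph.nodes if is_phase(n, Phase.FORWARD)]
# `partial` lets the same predicate be mapped over node users, the pattern used in the updated loop.
no_loss_nodes = not any(map(partial(is_phase, phase=Phase.LOSS), gm.graph.nodes))
print(len(fwd_nodes), no_loss_nodes)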
@@ -74,7 +60,7 @@ def is_saved(n: Node):
 
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
-    Basically the input graph should have all nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+    Basically the input graph should have all nodes marked for keyword `phase`.
     Nodes should have attribute `out` indicating the output of each node.
     ============================================================================
     Placeholder ----> p           o     <---- We need to keep track of grad out
@@ -91,18 +77,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
                                l
     =============================================================================
     Args:
-        graph (Graph): The autograd graph with nodes marked 'f' (forward), 'l' (loss), 'b' (backward) for keyword `stage`.
+        graph (Graph): The autograd graph with nodes marked for keyword `phase`.
 
     Returns:
         graph_info (GraphInfo): Meta information for the dataflow.
     """
 
     def _peak_memory(deps: Dict[Node, int]):
-        bwd_tmp = 0
+        peak_mem = 0
         for k, v in deps.items():
             if v > 0:
-                bwd_tmp += activation_size(k.meta['out'])
-        return bwd_tmp
+                peak_mem += activation_size(k.meta['out'])
+        return peak_mem
 
     # deps is used to track all the memory dependencies of the graph.
     deps = {}
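For reference, the renamed accumulator simply sums the sizes of the outputs whose dependency count is still positive. A self-contained sketch of that bookkeeping follows; the `FakeNode` stand-in and the byte-count approximation of `activation_size` are assumptions for the example, not the helper imported from `.memory`.

from dataclasses import dataclass, field
from typing import Dict

import torch


def activation_size_standin(out: torch.Tensor) -> int:
    # Assumed behaviour of the real `activation_size` helper: bytes occupied by a saved tensor.
    return out.numel() * out.element_size()


@dataclass(eq=False)            # identity-based hashing so instances can key the deps dict
class FakeNode:
    name: str
    meta: dict = field(default_factory=dict)


def peak_memory(deps: Dict[FakeNode, int]) -> int:
    peak_mem = 0
    for k, v in deps.items():
        if v > 0:               # output still needed by at least one unexecuted user
            peak_mem += activation_size_standin(k.meta['out'])
    return peak_mem


grad_a = FakeNode('grad_a', {'out': torch.zeros(1024)})   # float32, 4096 bytes, still live
grad_b = FakeNode('grad_b', {'out': torch.zeros(256)})    # already consumed (count 0)
print(peak_memory({grad_a: 2, grad_b: 0}))                # 4096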
@@ -110,19 +96,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
 
     for n in graph.nodes:
         n: Node
-        if is_saved(n) and not any(map(is_loss, n.users)):
+        if is_saved(n) and not any(map(partial(is_phase, phase=Phase.LOSS), n.users)):
             # A forward tensor who is marked `save` but is not
             # an input to `loss` should be saved during forward.
-            # If the tensor is a placeholder, then it belongs to `fwd_in`.
-            # Any `fwd_in` should be kept in memory even this function
+            # If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
+            # Any `fwd_mem_in` should be kept in memory even this function
             # is checkpointed.
-            # Otherwise, the tensor belongs to `fwd_tmp`. If we checkpoint
-            # the node, `fwd_tmp` can be freed.
-            if is_placeholder(n):
+            # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
+            # the node, `fwd_mem_tmp` can be freed.
+            if is_phase(n, Phase.PLACEHOLDER):
                 graph_info.fwd_mem_in += activation_size(n.meta['out'])
-            if is_forward(n):
+            if is_phase(n, Phase.FORWARD):
                 graph_info.fwd_mem_tmp += activation_size(n.meta['out'])
-        elif is_backward(n):
+        elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 # liveness analysis is only used in backward
                 deps[n] = len(n.users)