[hotfix/rotor] fix variable names (#1597)

* [fx] add some comment and docstrings.

* [fx] add dataflow analysis for an autograd graph.

* add interpretation for graph analysis.

* [fx] before doing save_tensor_hooks.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] provide an accurate estimation of memory except for GPT-2.

* [fx] a very accurate version on GPT-2.

* [fx] refactor code.

* [fx] remove redundant inplace=True.

* [fx] refactor code.

* [fx] refactor code.

* [fx] refactor code.

* [fx] dive into backward memory.

* [fx] fix variable names in ckpt_solvers and unskip tests.

* [fx] commit my changes.

* [fx] restore skips.

* [fx] restore skips.

* [fx] change stage into phase.

* [fx] change stage into phase.

* [fx] change stage into phase.
Super Daniel
2022-09-14 14:27:04 +08:00
committed by GitHub
parent faa23b9d9a
commit c8e9b2ad78
7 changed files with 85 additions and 83 deletions

View File

@@ -73,10 +73,11 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
         y = 0
         prev_idx = 2
         for (idx, n) in enumerate(gm.graph.nodes):
-            temp += getattr(n, 'fwd_out')
+            n: Node
+            temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
             y = max(y, temp)
             if temp > b and n in ckpt_nodes:
-                x += getattr(n, 'fwd_out')
+                x += n.meta['fwd_mem_out']
                 temp = 0
                 ckpt_intv.append((prev_idx, idx + 1))
                 prev_idx = idx + 1
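
The hunk above moves chen_greedy from ad-hoc node attributes such as fwd_out to the profiler metadata stored in node.meta. Below is a minimal sketch of how that metadata gets there and how it is read back, assuming MetaInfoProp has been run on the graph first; the toy model, the use of torch.fx.symbolic_trace, and the concrete sample tensor are illustrative placeholders (depending on the ColossalAI version, run() may expect a meta tensor instead), not part of this commit.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
gm = symbolic_trace(model)
MetaInfoProp(gm).run(torch.rand(2, 8))    # populates node.meta, as consumed by the hunk above

peak, running = 0, 0
for n in gm.graph.nodes:
    # same bookkeeping as the patched loop: per-node output memory plus temporary memory
    running += n.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_tmp', 0)
    peak = max(peak, running)
print(f"estimated forward peak: {peak} bytes")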

View File

@@ -1,11 +1,11 @@
-from typing import List, Set, Tuple, Dict
+from typing import List, Tuple
 import torch
 from torch.fx import GraphModule, Node
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.profiler import parameter_size
 import math
 from .linearize import linearize
 from .utils import *
 from colossalai.fx.profiler import profile_function, profile_module
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
@@ -25,8 +25,8 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     bw = chain.bweight    ## backward time, not used
     cw = chain.cweight + [0]    ## size of x (and of y)
     cbw = chain.cbweight + [0]    ## size of xbar
-    fwd_tmp = chain.fwd_tmp + [0]
-    bwd_tmp = chain.bwd_tmp + [0]
+    fwd_mem_tmp = chain.fwd_mem_tmp + [0]
+    bwd_mem_tmp = chain.bwd_mem_tmp + [0]

     # Build table
     opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
@@ -37,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     for m in range(mmax + 1):
         for i in range(chain.length + 1):
             #lmax-lmin = 0
-            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
+            limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
             if m >= limit:    ## Equation (1)
                 opt[m][i][i] = fw[i] + bw[i]
             else:
@@ -49,9 +49,9 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
             for i in range(chain.length + 1 - d):
                 # for idx in range(i+1, chain.length + 1):
                 idx = i + d
-                mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]
+                mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
                 if idx > i + 1:
-                    mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx)))
+                    mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
                 if m < mmin:
                     opt[m][i][idx] = float("inf")
                 else:
@@ -165,7 +165,7 @@ def _fwd_xbar(node: List[Node]) -> int:
     xbar = 0
     for n in node:
-        xbar += n.fwd_tmp + n.fwd_out
+        xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
     return xbar
@@ -183,7 +183,7 @@ def _fwd_time(node: List[Node]) -> int:
     fwd_time = 0
     for n in node:
         # minimum flop count is needed
-        fwd_time += max(n.fwd_flop, 1)
+        fwd_time += max(n.meta['fwd_flop'], 1)
     return fwd_time
@@ -201,11 +201,11 @@ def _bwd_time(node: List[Node]) -> int:
     bwd_time = 0
     for n in node:
         # minimum flop count is needed
-        bwd_time += max(n.bwd_flop, 1)
+        bwd_time += max(n.meta['bwd_flop'], 1)
     return bwd_time


-def _get_bwd_tmp(node: List[Node]) -> int:
+def _get_bwd_mem_tmp(node: List[Node]) -> int:
     """Get the backward temp memory of a node

     Args:
@@ -218,29 +218,32 @@ def _get_bwd_tmp(node: List[Node]) -> int:
     def _get_deps_size():
         deps_size = 0
-        for key in deps.keys():
-            deps_size += key.bwd_out
+        for k, v in deps.items():
+            if v > 0:
+                deps_size += k.meta['bwd_mem_out']
         return deps_size

-    bwd_tmp = 0
+    bwd_mem_tmp = 0
     deps = {}

     # add all the users for last node into deps,
     # as those nodes' gradient out will be stored in memory
-    for son in node[-1].users:
-        deps[son] = 1
+    for child in node[-1].users:
+        deps[child] = 1
     for n in reversed(node):
-        bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
-        deps[n] = len(n._input_nodes)
-        for son in n.users:
-            deps[son] -= 1
+        bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
+        deps[n] = len(n.all_input_nodes)
+        for child in n.users:
+            if child in deps:
+                deps[child] -= 1
         for key in list(deps.keys()):
             if deps[key] == 0:
                 del deps[key]
-    return bwd_tmp
+    return bwd_mem_tmp


 def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
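
To make the reference counting in the renamed _get_bwd_mem_tmp easier to follow, here is a toy walk-through on plain dicts rather than fx Nodes. This example is illustrative only: it drops the per-node 'bwd_mem_tmp' term and the seeding of the last node's users, both of which the real function includes.

users = {'a': ['b'], 'b': ['c'], 'c': []}        # hypothetical linear graph a -> b -> c
num_inputs = {'a': 0, 'b': 1, 'c': 1}            # how many inputs each node consumes
bwd_mem_out = {'a': 4, 'b': 4, 'c': 4}           # bytes held by each node's output gradient

deps, peak = {}, 0
for n in reversed(['a', 'b', 'c']):              # backward sweep over the segment
    peak = max(peak, sum(bwd_mem_out[k] for k, v in deps.items() if v > 0))
    deps[n] = num_inputs[n]                      # n's output grad stays live until all of n's inputs are processed
    for child in users[n]:
        if child in deps:
            deps[child] -= 1
    for k in [k for k, v in deps.items() if v == 0]:
        del deps[k]
print(peak)                                      # 4: one downstream gradient is live at the peak
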
@@ -267,7 +270,7 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
         bwd_time.append(_bwd_time(node))
         x_sizes.append(_compute_output_size(node))
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
-        tmp_bwd.append(_get_bwd_tmp(node))
+        tmp_bwd.append(_get_bwd_mem_tmp(node))

         # if a node with only one inplace op, we need to let x_bar = 0
         if len(node) == 1 and _get_inplace(node[0]):
@@ -394,6 +397,7 @@ def solver_rotor(gm: ColoGraphModule,
     """
     node_list = linearize(gm, cnode)
     mem_limit -= parameter_size(gm)
+    mem_unit = mem_limit * (1.0 - eps) // mem_slots
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data, mem_unit)
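
The added mem_unit line discretizes the memory budget that remains after subtracting parameter_size(gm) into the slots used by the rotor table. A quick numerical illustration with made-up values (the real defaults for mem_slots and eps may differ):

mem_limit = 8 * 1024**3    # hypothetical budget left after subtracting the parameters, in bytes
mem_slots = 500            # hypothetical number of discrete memory slots
eps = 0.02                 # hypothetical safety margin
mem_unit = mem_limit * (1.0 - eps) // mem_slots
print(mem_unit)            # 16836271.0 bytes, i.e. roughly 16 MiB per slot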

View File

@@ -5,24 +5,24 @@ class Chain:
         self.bweight = bw
         self.cweight = cw
         self.cbweight = cbw
-        self.fwd_tmp = ftmp
-        self.bwd_tmp = btmp
+        self.fwd_mem_tmp = ftmp
+        self.bwd_mem_tmp = btmp
         self.length = len(fw)
         if check and not self.check_lengths():
             raise AttributeError("In Chain, input lists do not have consistent lengths")

     def check_lengths(self):
         return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
-                and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length)
-                and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
+                and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
+                and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))

     def __repr__(self):
         chain_list = []
         for i in range(self.length):
-            chain_list.append(
-                (self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i]))
+            chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
+                               self.bwd_mem_tmp[i]))
         i = self.length
-        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i]))
+        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
         return chain_list.__repr__()
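
As a closing note on the renamed fields, check_lengths() encodes the length contract the rotor solver relies on: the forward-side lists carry one entry per stage, the backward-side lists one extra entry. A small illustration with made-up numbers; the constructor argument order is inferred from the assignments shown above, so treat the commented call as a sketch rather than the definitive API.

length = 3                    # a chain of three stages
fw    = [1, 1, 1]             # forward time per stage            -> length entries
ftmp  = [0, 0, 0]             # forward temp memory per stage     -> length entries
bw    = [2, 2, 2, 2]          # backward time                     -> length + 1 entries
cw    = [4, 4, 4, 4]          # size of x                         -> length + 1 entries
cbw   = [8, 8, 8, 8]          # size of xbar                      -> length + 1 entries
btmp  = [0, 0, 0, 0]          # backward temp memory              -> length + 1 entries

# chain = Chain(fw, bw, cw, cbw, ftmp, btmp)    # would pass check_lengths()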