Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-04 10:34:41 +00:00
[hotfix/rotor] fix variable names (#1597)
* [fx] add some comments and docstrings.
* [fx] add dataflow analysis for an autograd graph.
* [fx] add interpretation for graph analysis.
* [fx] before doing save_tensor_hooks.
* [fx] provide an accurate estimation of memory, except for GPT-2.
* [fx] a very accurate version on GPT-2.
* [fx] refactor code.
* [fx] remove redundant inplace=True.
* [fx] dive into backward memory.
* [fx] fix variable names in ckpt_solvers and unskip tests.
* [fx] commit my changes.
* [fx] restore skips.
* [fx] change stage into phase.
@@ -73,10 +73,11 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
         y = 0
         prev_idx = 2
         for (idx, n) in enumerate(gm.graph.nodes):
-            temp += getattr(n, 'fwd_out')
+            n: Node
+            temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
             y = max(y, temp)
             if temp > b and n in ckpt_nodes:
-                x += getattr(n, 'fwd_out')
+                x += n.meta['fwd_mem_out']
                 temp = 0
                 ckpt_intv.append((prev_idx, idx + 1))
                 prev_idx = idx + 1
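The hunk above switches the Chen greedy solver from ad-hoc node attributes (getattr(n, 'fwd_out')) to the profiler's meta dict (fwd_mem_out, fwd_mem_tmp). A minimal sketch of the same accumulate-and-reset pattern, using plain dicts instead of real torch.fx nodes; the helper name, node sizes and budget b below are invented, and the ckpt_nodes filter from the real code is omitted.

# Toy illustration of the greedy interval selection above. Each "node" is a
# plain dict standing in for a torch.fx Node whose profiler results live in
# a `meta` dict, as in the new code.
toy_nodes = [
    {'meta': {'fwd_mem_out': 4, 'fwd_mem_tmp': 1}},
    {'meta': {'fwd_mem_out': 6, 'fwd_mem_tmp': 2}},
    {'meta': {'fwd_mem_out': 3, 'fwd_mem_tmp': 0}},
]

def greedy_intervals(nodes, b):
    """Accumulate forward memory and cut a checkpoint interval once it exceeds b."""
    ckpt_intv, temp, x, y, prev_idx = [], 0, 0, 0, 0
    for idx, n in enumerate(nodes):
        temp += n['meta']['fwd_mem_out'] + n['meta']['fwd_mem_tmp']
        y = max(y, temp)                    # peak transient memory seen so far
        if temp > b:                        # budget exceeded: close an interval here
            x += n['meta']['fwd_mem_out']   # its output is kept as a checkpoint
            temp = 0
            ckpt_intv.append((prev_idx, idx + 1))
            prev_idx = idx + 1
    return ckpt_intv, x, y

print(greedy_intervals(toy_nodes, b=8))     # -> ([(0, 2)], 6, 13)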
@@ -1,11 +1,11 @@
-from typing import List, Set, Tuple, Dict
+from typing import List, Tuple
 import torch
 from torch.fx import GraphModule, Node
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.profiler import parameter_size
 import math
 from .linearize import linearize
 from .utils import *
 from colossalai.fx.profiler import profile_function, profile_module
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
@@ -25,8 +25,8 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     bw = chain.bweight    ## backward time, not used
     cw = chain.cweight + [0]    ## size of x (and of y)
     cbw = chain.cbweight + [0]    ## size of xbar
-    fwd_tmp = chain.fwd_tmp + [0]
-    bwd_tmp = chain.bwd_tmp + [0]
+    fwd_mem_tmp = chain.fwd_mem_tmp + [0]
+    bwd_mem_tmp = chain.bwd_mem_tmp + [0]
 
     # Build table
     opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
@@ -37,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     for m in range(mmax + 1):
         for i in range(chain.length + 1):
             #lmax-lmin = 0
-            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
+            limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
             if m >= limit:    ## Equation (1)
                 opt[m][i][i] = fw[i] + bw[i]
             else:
@@ -49,9 +49,9 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
             for i in range(chain.length + 1 - d):
                 # for idx in range(i+1, chain.length + 1):
                 idx = i + d
-                mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i]
+                mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
                 if idx > i + 1:
-                    mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx)))
+                    mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
                 if m < mmin:
                     opt[m][i][idx] = float("inf")
                 else:
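The three _compute_table hunks above only rename fwd_tmp/bwd_tmp to fwd_mem_tmp/bwd_mem_tmp, but the check they touch is the heart of the rotor table: a segment can only be scheduled under budget m if its output, its xbar and its transient memory fit. A toy, self-contained illustration of that feasibility check; all sizes and times below are invented and do not come from a real chain.

# Toy illustration of the feasibility check in _compute_table above. cw is the
# size of x, cbw the size of xbar, fwd/bwd_mem_tmp the transient memory of a segment.
INF = float('inf')
length, mmax = 3, 40
fw = [10, 12, 8]              # forward time per segment
bw = [20, 25, 15]             # backward time per segment
cw = [4, 4, 4, 0]             # size of x, padded with 0 as in the hunk
cbw = [6, 6, 6, 0]            # size of xbar, padded with 0
fwd_mem_tmp = [1, 2, 0]
bwd_mem_tmp = [3, 1, 0]

opt = [[{} for _ in range(length + 1)] for _ in range(mmax + 1)]
for m in range(mmax + 1):
    for i in range(length):
        # memory needed to run segment i forward and immediately backward
        limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i],
                    cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
        opt[m][i][i] = fw[i] + bw[i] if m >= limit else INF

print(opt[40][0][0], opt[5][0][0])    # -> 30 inf  (budget 40 fits, budget 5 does not)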
@@ -165,7 +165,7 @@ def _fwd_xbar(node: List[Node]) -> int:
 
     xbar = 0
     for n in node:
-        xbar += n.fwd_tmp + n.fwd_out
+        xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
     return xbar
 
 
@@ -183,7 +183,7 @@ def _fwd_time(node: List[Node]) -> int:
     fwd_time = 0
     for n in node:
         # minimum flop count is needed
-        fwd_time += max(n.fwd_flop, 1)
+        fwd_time += max(n.meta['fwd_flop'], 1)
     return fwd_time
 
 
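The _fwd_xbar and _fwd_time hunks above, and the _bwd_time hunk just below, follow one pattern: sum a profiled per-node quantity over a segment, now read from n.meta, clamping flop counts to at least 1 so no segment gets a zero cost. A standalone sketch of that pattern; the helper names and all numbers are invented for illustration.

# Segment-level estimators mirroring _fwd_xbar / _fwd_time / _bwd_time.
# Each "node" is a toy dict standing in for a torch.fx Node with profiler meta.
segment = [
    {'meta': {'fwd_mem_tmp': 2, 'fwd_mem_out': 4, 'fwd_flop': 0, 'bwd_flop': 120}},
    {'meta': {'fwd_mem_tmp': 0, 'fwd_mem_out': 8, 'fwd_flop': 300, 'bwd_flop': 600}},
]

def fwd_xbar(seg):
    # memory of the segment's saved intermediates: transient + output sizes
    return sum(n['meta']['fwd_mem_tmp'] + n['meta']['fwd_mem_out'] for n in seg)

def fwd_time(seg):
    # flop count as a proxy for time; clamp to 1 so every node has a nonzero cost
    return sum(max(n['meta']['fwd_flop'], 1) for n in seg)

def bwd_time(seg):
    return sum(max(n['meta']['bwd_flop'], 1) for n in seg)

print(fwd_xbar(segment), fwd_time(segment), bwd_time(segment))    # -> 14 301 720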
@@ -201,11 +201,11 @@ def _bwd_time(node: List[Node]) -> int:
     bwd_time = 0
     for n in node:
         # minimum flop count is needed
-        bwd_time += max(n.bwd_flop, 1)
+        bwd_time += max(n.meta['bwd_flop'], 1)
     return bwd_time
 
 
-def _get_bwd_tmp(node: List[Node]) -> int:
+def _get_bwd_mem_tmp(node: List[Node]) -> int:
     """Get the backward temp memory of a node
 
     Args:
@@ -218,29 +218,32 @@ def _get_bwd_tmp(node: List[Node]) -> int:
 
     def _get_deps_size():
         deps_size = 0
-        for key in deps.keys():
-            deps_size += key.bwd_out
+        for k, v in deps.items():
+            if v > 0:
+                deps_size += k.meta['bwd_mem_out']
 
         return deps_size
 
-    bwd_tmp = 0
+    bwd_mem_tmp = 0
     deps = {}
 
     # add all the users for last node into deps,
     # as those nodes' gradient out will be stored in memory
-    for son in node[-1].users:
-        deps[son] = 1
+    for child in node[-1].users:
+        deps[child] = 1
     for n in reversed(node):
-        bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
-        deps[n] = len(n._input_nodes)
-        for son in n.users:
-            deps[son] -= 1
+        bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
+
+        deps[n] = len(n.all_input_nodes)
+        for child in n.users:
+            if child in deps:
+                deps[child] -= 1
 
         for key in list(deps.keys()):
             if deps[key] == 0:
                 del deps[key]
 
-    return bwd_tmp
+    return bwd_mem_tmp
 
 
 def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
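The rewritten _get_bwd_mem_tmp estimates the peak temporary memory of a segment's backward pass by walking the nodes in reverse and keeping a reference count (deps) of gradients that are still needed. A self-contained sketch of the same counting idea on a tiny hand-built graph; ToyNode, the link helper and all sizes are illustrative stand-ins for real torch.fx nodes with profiler metadata.

class ToyNode:
    """Stand-in for a torch.fx Node: inputs, users, and profiled backward memory."""
    def __init__(self, name, bwd_mem_out, bwd_mem_tmp):
        self.name = name
        self.meta = {'bwd_mem_out': bwd_mem_out, 'bwd_mem_tmp': bwd_mem_tmp}
        self.all_input_nodes = []
        self.users = []

def link(src, dst):
    src.users.append(dst)
    dst.all_input_nodes.append(src)

# a -> b -> c inside the segment, plus a consumer d of c outside the segment
a = ToyNode('a', bwd_mem_out=4, bwd_mem_tmp=1)
b = ToyNode('b', bwd_mem_out=2, bwd_mem_tmp=3)
c = ToyNode('c', bwd_mem_out=6, bwd_mem_tmp=0)
d = ToyNode('d', bwd_mem_out=5, bwd_mem_tmp=0)
link(a, b); link(b, c); link(c, d)

def get_bwd_mem_tmp(node):
    def deps_size():
        # gradients whose consumers are not all finished yet are still live
        return sum(k.meta['bwd_mem_out'] for k, v in deps.items() if v > 0)

    bwd_mem_tmp, deps = 0, {}
    for child in node[-1].users:          # gradients flowing in from outside the segment
        deps[child] = 1
    for n in reversed(node):
        bwd_mem_tmp = max(bwd_mem_tmp, deps_size() + n.meta['bwd_mem_tmp'])
        deps[n] = len(n.all_input_nodes)
        for child in n.users:
            if child in deps:
                deps[child] -= 1
        for key in list(deps.keys()):     # drop gradients that are fully consumed
            if deps[key] == 0:
                del deps[key]
    return bwd_mem_tmp

print(get_bwd_mem_tmp([a, b, c]))         # -> 9 with this toy graph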
@@ -267,7 +270,7 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
         bwd_time.append(_bwd_time(node))
         x_sizes.append(_compute_output_size(node))
         xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
-        tmp_bwd.append(_get_bwd_tmp(node))
+        tmp_bwd.append(_get_bwd_mem_tmp(node))
 
         # if a node with only one inplace op, we need to let x_bar = 0
         if len(node) == 1 and _get_inplace(node[0]):
@@ -394,6 +397,7 @@ def solver_rotor(gm: ColoGraphModule,
     """
 
     node_list = linearize(gm, cnode)
     mem_limit -= parameter_size(gm)
     mem_unit = mem_limit * (1.0 - eps) // mem_slots
+    MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data, mem_unit)
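The only functional addition in this hunk is running MetaInfoProp over the graph before the chain is constructed, so every node carries the meta['fwd_mem_*'] / meta['bwd_mem_*'] entries the renamed helpers read. A hedged sketch of that ordering on a toy model; it uses plain torch.fx.symbolic_trace as a stand-in for ColossalAI's own tracer, and a real input tensor where the solver may in practice be given meta tensors.

# Sketch only: profile the graph with MetaInfoProp first, then read node.meta.
# The model, input and tracing path below are simplifying assumptions; the meta
# key names come from the hunks above.
import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
data = torch.rand(8, 64)

gm = symbolic_trace(model)
MetaInfoProp(gm).run(data)    # populates node.meta['fwd_mem_out'], ['bwd_mem_tmp'], ...

for node in gm.graph.nodes:
    print(node.name, node.meta.get('fwd_mem_out'), node.meta.get('bwd_mem_tmp'))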
@@ -5,24 +5,24 @@ class Chain:
         self.bweight = bw
         self.cweight = cw
         self.cbweight = cbw
-        self.fwd_tmp = ftmp
-        self.bwd_tmp = btmp
+        self.fwd_mem_tmp = ftmp
+        self.bwd_mem_tmp = btmp
         self.length = len(fw)
         if check and not self.check_lengths():
             raise AttributeError("In Chain, input lists do not have consistent lengths")
 
     def check_lengths(self):
         return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
-                and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length)
-                and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
+                and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
+                and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
 
     def __repr__(self):
         chain_list = []
         for i in range(self.length):
-            chain_list.append(
-                (self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i]))
+            chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
+                               self.bwd_mem_tmp[i]))
         i = self.length
-        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i]))
+        chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
         return chain_list.__repr__()
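The Chain hunk keeps the length convention that check_lengths enforces: forward-side lists (fweight, fwd_mem_tmp) have length entries, backward-side lists (bweight, cweight, cbweight, bwd_mem_tmp) have length + 1. A small example of toy lists that satisfy the renamed check; the values are invented and the check is restated inline rather than calling the real class, whose constructor argument order is not shown in this diff.

# Toy lists satisfying the convention checked by Chain.check_lengths above.
length = 3
fweight = [10, 12, 8]          # forward time per segment            -> length entries
fwd_mem_tmp = [1, 2, 0]        # forward transient memory            -> length entries
bweight = [20, 25, 15, 5]      # backward time (extra entry at the end) -> length + 1
cweight = [4, 4, 4, 0]         # size of x                            -> length + 1
cbweight = [6, 6, 6, 0]        # size of xbar                         -> length + 1
bwd_mem_tmp = [3, 1, 0, 0]     # backward transient memory            -> length + 1

consistent = ((len(fweight) == length) and (len(bweight) == length + 1)
              and (len(cweight) == length + 1) and (len(fwd_mem_tmp) == length)
              and (len(bwd_mem_tmp) == length + 1) and (len(cbweight) == length + 1))
assert consistent, "In Chain, input lists do not have consistent lengths"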