[fx] Improve linearize and rotor solver (#1586)
* [fx] add nested activation_checkpoint codegen
* undo algorithms commits
* solver
* undo some commits
* [fx] torch11 add nested activation checkpoint codegen
* remove some imports
* [fx] add some comments in activation codegen
* [fx] codegen instance error fix
* [fx] improve linearize and rotor solver
* [fx] some comments and format modification
@@ -7,6 +7,7 @@ from .linearize import linearize
 from .utils import *
 from colossalai.fx.profiler import profile_function, profile_module
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions


 # this is the python compute table code from rotor
@@ -36,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     for m in range(mmax + 1):
         for i in range(chain.length + 1):
             #lmax-lmin = 0
-            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
+            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
             if m >= limit:    ## Equation (1)
                 opt[m][i][i] = fw[i] + bw[i]
             else:
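The changed line above is the base case of the rotor dynamic program: a budget of m memory slots can run stage i forward and backward back-to-back only if it covers the larger of the forward-side and backward-side peaks, and the commit drops the extra cw[i] term from the backward side. Below is a minimal runnable sketch of that feasibility check with toy, already-discretized slot counts; reading cw/cbw as the chain's x and xbar sizes is an assumption inferred from the surrounding solver, not something stated in this hunk.

def stage_is_feasible(m, i, cw, cbw, fwd_tmp, bwd_tmp):
    # same shape as the updated Equation (1) check above
    limit = max(
        cw[i + 1] + cbw[i + 1] + fwd_tmp[i],    # term gated by forward temporaries
        cw[i + 1] + cbw[i + 1] + bwd_tmp[i],    # term gated by backward temporaries
    )
    return m >= limit

# toy slot counts for a 2-stage chain (illustrative numbers only)
cw = [2, 2, 3]         # assumed: per-stage activation (x) sizes
cbw = [0, 4, 5]        # assumed: per-stage xbar sizes
fwd_tmp = [1, 1]
bwd_tmp = [2, 2]
print(stage_is_feasible(8, 0, cw, cbw, fwd_tmp, bwd_tmp))    # True:  8 >= max(7, 8)
print(stage_is_feasible(5, 0, cw, cbw, fwd_tmp, bwd_tmp))    # False: 5 <  8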
@@ -151,6 +152,97 @@ def _get_inplace(node: Node) -> bool:
     return is_inplace


+def _fwd_xbar(node: List[Node]) -> int:
+    """Get the forward xbar of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+        indicates a node in linearized graph
+
+    Returns:
+        int: xbar size, unit Byte
+    """
+
+    xbar = 0
+    for n in node:
+        xbar += n.fwd_tmp + n.fwd_out
+    return xbar
+
+
+def _fwd_time(node: List[Node]) -> int:
+    """Get the forward time of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+        indicates a node in linearized graph
+
+    Returns:
+        int: forward time, estimated by flops count
+    """
+
+    fwd_time = 0
+    for n in node:
+        # minimum flop count is needed
+        fwd_time += max(n.fwd_flop, 1)
+    return fwd_time
+
+
+def _bwd_time(node: List[Node]) -> int:
+    """Get the backward time of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+        indicates a node in linearized graph
+
+    Returns:
+        int: backward time, estimated by flops count
+    """
+
+    bwd_time = 0
+    for n in node:
+        # minimum flop count is needed
+        bwd_time += max(n.bwd_flop, 1)
+    return bwd_time
+
+
+def _get_bwd_tmp(node: List[Node]) -> int:
+    """Get the backward temp memory of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+        indicates a node in linearized graph
+
+    Returns:
+        int: backward temp memory, unit Byte
+    """
+
+    def _get_deps_size():
+        deps_size = 0
+        for key in deps.keys():
+            deps_size += key.bwd_out
+
+        return deps_size
+
+    bwd_tmp = 0
+    deps = {}
+
+    # add all the users of the last node into deps,
+    # as those nodes' gradient outputs will be stored in memory
+    for son in node[-1].users:
+        deps[son] = 1
+    for n in reversed(node):
+        bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
+        deps[n] = len(n._input_nodes)
+        for son in n.users:
+            deps[son] -= 1
+
+        for key in list(deps.keys()):
+            if deps[key] == 0:
+                del deps[key]
+
+    return bwd_tmp
+
+
 def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:

     fwd_time = []
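Of the new helpers, _get_bwd_tmp is the least obvious: it replays the backward pass of one linearized node in reverse, keeps a gradient "alive" until every op that consumes the corresponding output has run its backward, and records the peak of live gradients plus per-op scratch memory. The toy below runs that reference-counting idea end to end; the FakeNode objects and their byte counts are invented stand-ins for the attributes the profiler attaches to real torch.fx nodes, not ColossalAI APIs.

class FakeNode:
    def __init__(self, name, bwd_out, bwd_tmp):
        self.name = name
        self.bwd_out = bwd_out      # size of the gradient this op produces
        self.bwd_tmp = bwd_tmp      # scratch memory its backward needs
        self.users = []             # ops that consume this op's output
        self._input_nodes = {}      # ops whose output this op consumes

def link(src, dst):
    src.users.append(dst)
    dst._input_nodes[src] = None

# a tiny chain a -> b -> c treated as one linearized node
a, b, c = FakeNode("a", 64, 8), FakeNode("b", 128, 16), FakeNode("c", 32, 4)
link(a, b)
link(b, c)
node = [a, b, c]

deps = {son: 1 for son in node[-1].users}   # gradients flowing out of the region
peak = 0
for n in reversed(node):
    live = sum(k.bwd_out for k in deps)     # gradients still needed by later backwards
    peak = max(peak, live + n.bwd_tmp)
    deps[n] = len(n._input_nodes)
    for son in n.users:
        deps[son] -= 1
    deps = {k: v for k, v in deps.items() if v != 0}
print(peak)                                 # 136: b's gradient (128) + a's scratch (8)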
@@ -160,45 +252,32 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
         xbar_sizes = [_compute_size(data)]
         x_sizes = [_compute_size(data)]
     elif isinstance(data, list) or isinstance(data, tuple):
-        xbar_sizes = [_compute_size(obj) for obj in data]
-        x_sizes = [_compute_size(obj) for obj in data]
+        xbar_sizes = [sum([_compute_size(obj) for obj in data])]
+        x_sizes = [sum([_compute_size(obj) for obj in data])]
     elif isinstance(data, dict):
-        xbar_sizes = [_compute_size(obj) for obj in data.values()]
-        x_sizes = [_compute_size(obj) for obj in data.values()]
+        xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
+        x_sizes = [sum([_compute_size(obj) for obj in data.values()])]

-    # currently we can't get the temp memory needed in fwd and bwd
+    # currently we can't get the temp memory needed in fwd
     tmp_fwd = [0] * len(node_list)
-    tmp_bwd = [0] * (len(node_list) + 1)
+    tmp_bwd = []

     for idx, node in enumerate(node_list):
-        fwd_time.append(0)
-        bwd_time.append(0)
-        xbar_sizes.append(0)
+        fwd_time.append(_fwd_time(node))
+        bwd_time.append(_bwd_time(node))
         x_sizes.append(_compute_output_size(node))
+        xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
+        tmp_bwd.append(_get_bwd_tmp(node))

-        _check_inplace_flag = 1
-        for n in node:
-            fwd_time[-1] += max(n.__flops__, 1)
-
-            # currently we haven't patched the backward flops count
-            bwd_time[-1] += max(n.__flops__ * 2, 2)
-            xbar_sizes[-1] += n.__activation__
-
-            # we need to clear the xbar of previous node as there is
-            # one op in the current node that use the previous node's
-            # output but applies inplace operation on it
-            # NOTE: This process should be done only once as the previous
-            # node will only have one output
-            if _check_inplace_flag:
-                for par in n._input_nodes:
-                    if par not in node and _get_inplace(n):
-                        xbar_sizes[-2] -= x_sizes[-2]
-                        _check_inplace_flag = 0
-
-        xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])
+        # if a node has only one inplace op, we need to let x_bar = 0
+        if len(node) == 1 and _get_inplace(node[0]):
+            xbar_sizes[-1] = 0

     bwd_time.append(0)

+    # currently we view loss backward temp as zero
+    tmp_bwd.append(0)
+
     xbar_sizes = _discretize(mem_unit, xbar_sizes)
     x_sizes = _discretize(mem_unit, x_sizes)
     tmp_fwd = _discretize(mem_unit, tmp_fwd)
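The chain construction ends by discretizing every byte count into integer memory slots so that _compute_table can index its table by slots rather than raw bytes. _discretize itself is not shown in this diff; the sketch below assumes a round-up-to-whole-slots behaviour purely for illustration.

import math

def discretize(mem_unit, values):
    # assumption: each size is rounded up to a whole number of slots;
    # the real _discretize in ColossalAI may round differently
    return [math.ceil(v / mem_unit) for v in values]

mem_unit = 16 * 1024**2                      # 16 MiB per slot (illustrative)
sizes = [0, 50 * 1024**2, 3 * 1024**2]       # byte counts, e.g. xbar_sizes
print(discretize(mem_unit, sizes))           # [0, 4, 1]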
@@ -207,14 +286,17 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
     return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)


-def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> GraphModule:
+def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
     op_list = sequence.list_operations()
     loss_op = next(op for op in op_list if isinstance(op, Loss))
-    op_list = op_list[:op_list.index(loss_op)]
+    fwd_list = op_list[:op_list.index(loss_op)]
+    bwd_list = op_list[op_list.index(loss_op) + 1:]
     ckpt_idx = 0
     in_ckpt = False
     ckpt_region = []
-    for idx, op in enumerate(op_list, 0):
+
+    # forward annotation
+    for idx, op in enumerate(fwd_list, 0):
         if in_ckpt:
             if isinstance(op, ForwardNograd):
                 ckpt_region.append(idx)
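The forward pass over fwd_list turns the rotor schedule into per-node labels: a ForwardCheck opens a checkpoint region, consecutive ForwardNograd ops extend it, and a ForwardEnable closes it and flushes one checkpoint index onto every node inside. Below is a stripped-down illustration with plain strings standing in for the rotor op classes; the op sequence is invented for the example.

def annotate_forward(ops):
    ckpt, ckpt_idx, in_ckpt, region = {}, 0, False, []
    for idx, op in enumerate(ops):
        if in_ckpt:
            if op == "nograd":                 # region keeps growing
                region.append(idx)
            elif op == "enable":               # region ends: flush one label
                for i in region:
                    ckpt[i] = ckpt_idx
                ckpt_idx, in_ckpt, region = ckpt_idx + 1, False, []
            elif op == "check":                # region ends, a new one starts here
                for i in region:
                    ckpt[i] = ckpt_idx
                ckpt_idx, region = ckpt_idx + 1, [idx]
        elif op == "check":
            in_ckpt, region = True, [idx]
    return ckpt

ops = ["enable", "check", "nograd", "nograd", "enable", "check", "nograd"]
print(annotate_forward(ops))    # {1: 0, 2: 0, 3: 0} -> ops 1-3 share checkpoint 0
# (the trailing region opened by the last "check" is simply left open in this toy)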
@@ -223,7 +305,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
                 in_ckpt = False
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        setattr(n, "activation_checkpoint", ckpt_idx)
+                        setattr(n, "activation_checkpoint", [ckpt_idx])

                 ckpt_idx += 1
                 ckpt_region = []
@@ -231,7 +313,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
             elif isinstance(op, ForwardCheck):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        setattr(n, "activation_checkpoint", ckpt_idx)
+                        setattr(n, "activation_checkpoint", [ckpt_idx])

                 ckpt_idx += 1
                 ckpt_region = [idx]
@@ -241,12 +323,62 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
                 in_ckpt = True
                 ckpt_region.append(idx)

+    # annotate the backward if there is any nested activation checkpoint
+    in_recompute = False
+    for op in bwd_list:
+        if in_recompute:
+            if isinstance(op, ForwardNograd):
+                ckpt_region.append(op.index)
+
+            elif isinstance(op, ForwardEnable):
+                for node_idx in ckpt_region:
+                    for n in node_list[node_idx]:
+                        n.activation_checkpoint.append(ckpt_idx)
+
+                ckpt_idx += 1
+                ckpt_region = []
+
+            elif isinstance(op, ForwardCheck):
+                for node_idx in ckpt_region:
+                    for n in node_list[node_idx]:
+                        n.activation_checkpoint.append(ckpt_idx)
+
+                ckpt_idx += 1
+                ckpt_region = [op.index]
+
+            elif isinstance(op, Backward):
+                for node_idx in ckpt_region:
+                    for n in node_list[node_idx]:
+                        n.activation_checkpoint.append(ckpt_idx)
+
+                in_recompute = False
+
+        else:
+            if not isinstance(op, Backward):
+                in_recompute = True
+                ckpt_idx = 0
+                ckpt_region = []
+                if isinstance(op, ForwardCheck):
+                    ckpt_region.append(op.index)
+
+    # postprocess, make sure every activation checkpoint label in the
+    # same activation checkpoint region (level = 0) has the same length
+    op_list = []
+    for node in node_list:
+        op_list += node
+    ckpt_regions = _find_nested_ckpt_regions(op_list)
+    for (start_idx, end_idx) in ckpt_regions:
+        nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
+        for idx in range(start_idx, end_idx + 1):
+            op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
+
+
 def solver_rotor(gm: ColoGraphModule,
                  data,
                  mem_limit: int,
                  mem_slots: int = 500,
-                 cnode: List[str] = None) -> ColoGraphModule:
+                 cnode: List[str] = None,
+                 eps: float = 0.02) -> ColoGraphModule:
     """solver that automatically find activation checkpoint in rotor's manner

     Args:
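Because the backward pass can re-open checkpoints inside a region that is being recomputed, each node's activation_checkpoint attribute is now a list (one index per nesting level) instead of a single integer, and the postprocess pads every list in a top-level region to the same length with None. A small hypothetical picture of what that padding does, with invented label values:

# three linearized nodes in top-level region 0; during recomputation the solver
# nests an inner checkpoint around the first two of them only
labels = [[0, 0], [0, 0], [0]]       # per-node activation_checkpoint lists
nested_length = max(len(l) for l in labels)
for l in labels:
    l += [None] * (nested_length - len(l))
print(labels)                        # [[0, 0], [0, 0], [0, None]]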
@@ -255,13 +387,14 @@ def solver_rotor(gm: ColoGraphModule,
         mem_limit (int): memory budget in Byte.
         mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
         cnode (List[Node], optional): common node list for linearize. Defaults to None.
+        eps (float): epsilon for memory decay. Defaults to 0.02

     Returns:
         ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
     """

     node_list = linearize(gm, cnode)
-    mem_unit = mem_limit // mem_slots
+    mem_unit = mem_limit * (1.0 - eps) // mem_slots
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data, mem_unit)
     opt_table = _compute_table(chain, mem_slots)
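Numerically, the new eps parameter shaves the budget before it is split into slots, giving the solver a small safety margin against under-estimated activation sizes. Below is a quick check of the formula plus a hypothetical call site; every identifier other than the keyword arguments shown in this diff is a placeholder.

mem_limit = 8 * 1024**3                          # 8 GiB budget
mem_slots, eps = 500, 0.02
mem_unit = mem_limit * (1.0 - eps) // mem_slots
print(mem_unit)                                  # 16836271.0 bytes, i.e. ~16 MiB per slot

# hypothetical usage, assuming `gm` is an already-traced ColoGraphModule and
# `sample_input` is the same input later fed to MetaInfoProp:
# gm = solver_rotor(gm, data=sample_input, mem_limit=mem_limit,
#                   mem_slots=mem_slots, cnode=["attention_mask"], eps=eps)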
@@ -1,6 +1,35 @@
-from typing import List
+from typing import List, Any
 from torch.fx import GraphModule, Node

+# Common nodes are nodes that can be seen as attributes and remain unchanged
+# throughout the whole model. They may be used several times by different blocks
+# of the model, which makes it hard to linearize the graph when we encounter them.
+# We let users annotate some of the inputs as common nodes, such as the attention
+# mask, and the following are some of the ops whose outputs can also be seen as
+# common nodes. With this common node propagation, we can find some of the "real"
+# common nodes (e.g. the real attention mask used in BERT and GPT). The rule is
+# simple: if all of a node's parents are common nodes, or its op belongs to the
+# following operations, we view this node as a newly born common node.
+# List of target names that can be seen as common nodes
+COPS = ["getattr", "getitem", "size"]
+
+
+def _is_cop(target: Any) -> bool:
+    """Check if an op could be seen as a common node
+
+    Args:
+        target (Any): node target
+
+    Returns:
+        bool
+    """
+
+    if isinstance(target, str):
+        return target in COPS
+    else:
+        return target.__name__ in COPS
+
+
 def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
     """Linearizing the graph
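COPS lists the op targets that _is_cop treats as common: for call_method nodes the target is already a string, while for call_function nodes it is the callable itself, so its __name__ is compared instead. A short standalone illustration; the function is re-declared here only so the snippet runs on its own.

import operator

COPS = ["getattr", "getitem", "size"]

def _is_cop(target):
    if isinstance(target, str):
        return target in COPS
    return target.__name__ in COPS

print(_is_cop("size"))              # True:  e.g. a call_method node x.size()
print(_is_cop(operator.getitem))    # True:  e.g. indexing into an attention mask
print(_is_cop(operator.add))        # False: a real compute op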
@@ -53,7 +82,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
                 region = []

             # propagate common node attr if possible
-            if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]):
+            if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
                 cnode.append(n.name)
             else:
                 deps[n] = len([user for user in n.users if user.op != "output"])
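The changed condition above is the propagation rule: a node joins the common-node set when all of its inputs are already common, or when its own target is one of the COPS ops. A toy run of that rule over an invented three-node dependency table (real graphs come from torch.fx tracing, not from a list like this):

cnode = ["attention_mask"]                        # user-annotated common nodes
graph = [                                         # (name, target, input names)
    ("mask_unsqueezed", "getitem", ["attention_mask"]),
    ("mask_scaled", "mul", ["mask_unsqueezed"]),
    ("hidden", "linear", ["input_ids"]),
]

COPS = ["getattr", "getitem", "size"]
for name, target, inputs in graph:
    # all parents common, or the op itself is a common op -> newly born common node
    if all(p in cnode for p in inputs) or target in COPS:
        cnode.append(name)

print(cnode)   # ['attention_mask', 'mask_unsqueezed', 'mask_scaled']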