diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
index b72d20fd2..0d8ed9553 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -7,6 +7,7 @@ from .linearize import linearize
 from .utils import *
 from colossalai.fx.profiler import profile_function, profile_module
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions


 # this is the python compute table code from rotor
@@ -36,7 +37,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
     for m in range(mmax + 1):
         for i in range(chain.length + 1):
             #lmax-lmin = 0
-            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
+            limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_tmp[i])
             if m >= limit:    ## Equation (1)
                 opt[m][i][i] = fw[i] + bw[i]
             else:
@@ -151,6 +152,97 @@ def _get_inplace(node: Node) -> bool:
     return is_inplace


+def _fwd_xbar(node: List[Node]) -> int:
+    """Get the forward xbar of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+            indicates a node in the linearized graph
+
+    Returns:
+        int: xbar size, unit: Byte
+    """
+
+    xbar = 0
+    for n in node:
+        xbar += n.fwd_tmp + n.fwd_out
+    return xbar
+
+
+def _fwd_time(node: List[Node]) -> int:
+    """Get the forward time of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+            indicates a node in the linearized graph
+
+    Returns:
+        int: forward time, estimated by flop count
+    """
+
+    fwd_time = 0
+    for n in node:
+        # enforce a minimum flop count of 1
+        fwd_time += max(n.fwd_flop, 1)
+    return fwd_time
+
+
+def _bwd_time(node: List[Node]) -> int:
+    """Get the backward time of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+            indicates a node in the linearized graph
+
+    Returns:
+        int: backward time, estimated by flop count
+    """
+
+    bwd_time = 0
+    for n in node:
+        # enforce a minimum flop count of 1
+        bwd_time += max(n.bwd_flop, 1)
+    return bwd_time
+
+
+def _get_bwd_tmp(node: List[Node]) -> int:
+    """Get the backward temp memory of a node
+
+    Args:
+        node (List[Node]): List of torch.fx Node,
+            indicates a node in the linearized graph
+
+    Returns:
+        int: backward temp memory, unit: Byte
+    """
+
+    def _get_deps_size():
+        deps_size = 0
+        for key in deps.keys():
+            deps_size += key.bwd_out
+
+        return deps_size
+
+    bwd_tmp = 0
+    deps = {}
+
+    # add all the users of the last node into deps,
+    # as those nodes' gradient outputs will be kept in memory
+    for son in node[-1].users:
+        deps[son] = 1
+    for n in reversed(node):
+        bwd_tmp = max(bwd_tmp, _get_deps_size() + n.bwd_tmp)
+        deps[n] = len(n._input_nodes)
+        for son in n.users:
+            deps[son] -= 1
+
+        for key in list(deps.keys()):
+            if deps[key] == 0:
+                del deps[key]
+
+    return bwd_tmp
+
+
 def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:

     fwd_time = []
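Reviewer note on `_get_bwd_tmp`: it estimates each linearized stage's peak backward memory by reference-counting gradient buffers while walking the stage in reverse. The toy below replays the same bookkeeping on two invented nodes; `ToyNode` and every size in it are illustrative assumptions, not part of this patch.

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass(eq=False)    # eq=False keeps identity hashing so nodes can be dict keys
class ToyNode:
    name: str
    bwd_out: int                    # bytes of gradients this node's backward produces
    bwd_tmp: int                    # transient workspace bytes of its backward
    inputs: List["ToyNode"] = field(default_factory=list)
    users: List["ToyNode"] = field(default_factory=list)

def toy_bwd_tmp(node: List[ToyNode]) -> int:
    # gradients flowing in from stages after this one
    deps: Dict[ToyNode, int] = {u: 1 for u in node[-1].users}
    peak = 0
    for n in reversed(node):        # backward runs in reverse topological order
        live = sum(k.bwd_out for k in deps)
        peak = max(peak, live + n.bwd_tmp)
        deps[n] = len(n.inputs)     # n's grads stay live until each input's backward has run
        for u in n.users:
            deps[u] -= 1            # n's backward consumed the grad that u produced for it
        for k in list(deps):
            if deps[k] == 0:
                del deps[k]         # every consumer has run: the buffer is freed
    return peak

a = ToyNode("a", bwd_out=4, bwd_tmp=1)
b = ToyNode("b", bwd_out=2, bwd_tmp=3, inputs=[a])
a.users = [b]
print(toy_bwd_tmp([a, b]))          # 3 = max(b.bwd_tmp, b.bwd_out + a.bwd_tmp)
```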
@@ -160,45 +252,32 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
         xbar_sizes = [_compute_size(data)]
         x_sizes = [_compute_size(data)]
     elif isinstance(data, list) or isinstance(data, tuple):
-        xbar_sizes = [_compute_size(obj) for obj in data]
-        x_sizes = [_compute_size(obj) for obj in data]
+        xbar_sizes = [sum([_compute_size(obj) for obj in data])]
+        x_sizes = [sum([_compute_size(obj) for obj in data])]
     elif isinstance(data, dict):
-        xbar_sizes = [_compute_size(obj) for obj in data.values()]
-        x_sizes = [_compute_size(obj) for obj in data.values()]
+        xbar_sizes = [sum([_compute_size(obj) for obj in data.values()])]
+        x_sizes = [sum([_compute_size(obj) for obj in data.values()])]

-    # currently we can't get the temp memory needed in fwd and bwd
+    # currently we can't get the temp memory needed in fwd
     tmp_fwd = [0] * len(node_list)
-    tmp_bwd = [0] * (len(node_list) + 1)
+    tmp_bwd = []

     for idx, node in enumerate(node_list):
-        fwd_time.append(0)
-        bwd_time.append(0)
-        xbar_sizes.append(0)
+        fwd_time.append(_fwd_time(node))
+        bwd_time.append(_bwd_time(node))
         x_sizes.append(_compute_output_size(node))
+        xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
+        tmp_bwd.append(_get_bwd_tmp(node))

-        _check_inplace_flag = 1
-        for n in node:
-            fwd_time[-1] += max(n.__flops__, 1)
-
-            # currently we haven't patched the backward flops count
-            bwd_time[-1] += max(n.__flops__ * 2, 2)
-            xbar_sizes[-1] += n.__activation__
-
-            # we need to clear the xbar of previous node as there is
-            # one op in the current node that use the previous node's
-            # output but applies inplace operation on it
-            # NOTE: This process should be done only once as the previous
-            # node will only have one output
-            if _check_inplace_flag:
-                for par in n._input_nodes:
-                    if par not in node and _get_inplace(n):
-                        xbar_sizes[-2] -= x_sizes[-2]
-                        _check_inplace_flag = 0
-
-        xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])
+        # if a node consists of a single in-place op, its xbar must be 0
+        if len(node) == 1 and _get_inplace(node[0]):
+            xbar_sizes[-1] = 0

     bwd_time.append(0)

+    # currently we treat the loss backward temp memory as zero
+    tmp_bwd.append(0)
+
     xbar_sizes = _discretize(mem_unit, xbar_sizes)
     x_sizes = _discretize(mem_unit, x_sizes)
     tmp_fwd = _discretize(mem_unit, tmp_fwd)
@@ -207,14 +286,17 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
     return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)


-def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> GraphModule:
+def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
     op_list = sequence.list_operations()
     loss_op = next(op for op in op_list if isinstance(op, Loss))
-    op_list = op_list[:op_list.index(loss_op)]
+    fwd_list = op_list[:op_list.index(loss_op)]
+    bwd_list = op_list[op_list.index(loss_op) + 1:]
     ckpt_idx = 0
     in_ckpt = False
     ckpt_region = []
-    for idx, op in enumerate(op_list, 0):
+
+    # forward annotation
+    for idx, op in enumerate(fwd_list, 0):
         if in_ckpt:
             if isinstance(op, ForwardNograd):
                 ckpt_region.append(idx)
@@ -223,7 +305,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
                 in_ckpt = False
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        setattr(n, "activation_checkpoint", ckpt_idx)
+                        setattr(n, "activation_checkpoint", [ckpt_idx])

                 ckpt_idx += 1
                 ckpt_region = []
@@ -231,7 +313,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) ->
             elif isinstance(op, ForwardCheck):
                 for node_idx in ckpt_region:
                     for n in node_list[node_idx]:
-                        setattr(n, "activation_checkpoint", ckpt_idx)
+                        setattr(n, "activation_checkpoint", [ckpt_idx])

                 ckpt_idx += 1
                 ckpt_region = [idx]
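To make the new `Chain` bookkeeping concrete: for a linearized graph with three stages, the six lists line up as sketched below. All numbers are invented for illustration; only the shapes and the leading/trailing entries follow from the code above.

```python
# Toy Chain inputs for a 3-stage linearized graph (values invented):
fwd_time   = [120, 250, 90]        # one flop-count estimate per stage
bwd_time   = [240, 500, 180, 0]    # one per stage, plus a trailing 0 for the loss
x_sizes    = [4, 2, 2, 1]          # model input first, then each stage's output
xbar_sizes = [4, 3, 5, 1]          # >= x_sizes: stage output plus saved fwd temporaries
tmp_fwd    = [0, 0, 0]             # fwd temp memory is currently unknown, hence 0
tmp_bwd    = [6, 9, 3, 0]          # peak live grads + bwd workspace, 0 for the loss

# after _discretize, the four size lists are expressed in units of mem_unit and
# the solver receives Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
```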
@@ -241,12 +323,62 @@
             in_ckpt = True
             ckpt_region.append(idx)

+    # annotate the backward pass if there is any nested activation checkpoint
+    in_recompute = False
+    for op in bwd_list:
+        if in_recompute:
+            if isinstance(op, ForwardNograd):
+                ckpt_region.append(op.index)
+
+            elif isinstance(op, ForwardEnable):
+                for node_idx in ckpt_region:
+                    for n in node_list[node_idx]:
+                        n.activation_checkpoint.append(ckpt_idx)
+
+                ckpt_idx += 1
+                ckpt_region = []
+
+            elif isinstance(op, ForwardCheck):
+                for node_idx in ckpt_region:
+                    for n in node_list[node_idx]:
+                        n.activation_checkpoint.append(ckpt_idx)
+
+                ckpt_idx += 1
+                ckpt_region = [op.index]
+
+            elif isinstance(op, Backward):
+                for node_idx in ckpt_region:
+                    for n in node_list[node_idx]:
+                        n.activation_checkpoint.append(ckpt_idx)
+
+                in_recompute = False
+
+        else:
+            if not isinstance(op, Backward):
+                in_recompute = True
+                ckpt_idx = 0
+                ckpt_region = []
+                if isinstance(op, ForwardCheck):
+                    ckpt_region.append(op.index)
+
+    # postprocess: make sure every activation checkpoint label in the same
+    # top-level (level = 0) checkpoint region has the same length
+    op_list = []
+    for node in node_list:
+        op_list += node
+    ckpt_regions = _find_nested_ckpt_regions(op_list)
+    for (start_idx, end_idx) in ckpt_regions:
+        nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
+        for idx in range(start_idx, end_idx + 1):
+            op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
+
+
 def solver_rotor(gm: ColoGraphModule,
                  data,
                  mem_limit: int,
                  mem_slots: int = 500,
-                 cnode: List[str] = None) -> ColoGraphModule:
+                 cnode: List[str] = None,
+                 eps: float = 0.02) -> ColoGraphModule:
     """solver that automatically find activation checkpoint in rotor's manner

     Args:
@@ -255,13 +387,14 @@ def solver_rotor(gm: ColoGraphModule,
         mem_limit (int): memory budget in Byte.
         mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
         cnode (List[Node], optional): common node list for linearize. Defaults to None.
+        eps (float, optional): fraction of the memory budget withheld as slack during discretization. Defaults to 0.02.

     Returns:
         ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
     """

     node_list = linearize(gm, cnode)
-    mem_unit = mem_limit // mem_slots
+    mem_unit = mem_limit * (1.0 - eps) // mem_slots
     MetaInfoProp(gm).run(data)
     chain: Chain = _construct_chain(node_list, data, mem_unit)
     opt_table = _compute_table(chain, mem_slots)
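The postprocess step is easiest to see on a toy region: within one top-level checkpoint region, every node's annotation list gets right-padded with `None` so all labels share the same nesting depth. A minimal, self-contained rehearsal of that logic (toy classes, not the patch's code):

```python
# One top-level region whose middle node also sits in nested region 1.
class ToyOp:

    def __init__(self, ckpt):
        self.activation_checkpoint = ckpt

region = [ToyOp([0]), ToyOp([0, 1]), ToyOp([0])]

# pad every label to the deepest nesting level found in the region
nested_length = max(len(op.activation_checkpoint) for op in region)
for op in region:
    op.activation_checkpoint += [None] * (nested_length - len(op.activation_checkpoint))

print([op.activation_checkpoint for op in region])
# [[0, None], [0, 1], [0, None]]
```

This is also why `_find_nested_ckpt_regions` is now imported at the top of the file: it recovers the top-level `(start_idx, end_idx)` spans that the padding loop iterates over.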
diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py
index f8c531356..043827a76 100644
--- a/colossalai/fx/passes/algorithms/linearize.py
+++ b/colossalai/fx/passes/algorithms/linearize.py
@@ -1,6 +1,35 @@
-from typing import List
+from typing import List, Any
 from torch.fx import GraphModule, Node

+# Common nodes are nodes that can be seen as attributes and remain unchanged
+# throughout the whole model. They may be used several times by different
+# blocks of the model, which makes it hard to linearize the graph when we
+# encounter them. We let users annotate some of the inputs as common nodes,
+# such as the attention mask, and the ops listed below can also be seen as
+# common nodes. With common node propagation we can find the "real" common
+# nodes (e.g. the real attention mask used in BERT and GPT). The rule is
+# simple: if all of a node's parents are common nodes, or if its op belongs
+# to the following operations, we view the node as a newly born common
+# node.
+
+# List of target names that could be seen as common nodes
+COPS = ["getattr", "getitem", "size"]
+
+
+def _is_cop(target: Any) -> bool:
+    """Check if an op could be seen as a common node
+
+    Args:
+        target (Any): node target
+
+    Returns:
+        bool: True if the target is a common op
+    """
+
+    if isinstance(target, str):
+        return target in COPS
+    else:
+        return target.__name__ in COPS
+

 def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
     """Linearizing the graph
@@ -53,7 +82,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
            region = []

            # propagate common node attr if possible
-           if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]):
+           if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
               cnode.append(n.name)
           else:
               deps[n] = len([user for user in n.users if user.op != "output"])
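The `_is_cop` hook is best seen together with the propagation rule it extends: a node is promoted to a common node when all of its parents are already common, or now also when its target is one of the ops in `COPS`. A toy rehearsal of that rule (invented names, not the patch's code):

```python
COPS = ["getattr", "getitem", "size"]

def is_cop(target) -> bool:
    # call_method targets are strings (e.g. "size"); call_function targets are callables
    return target in COPS if isinstance(target, str) else target.__name__ in COPS

def propagate(parents, target, cnode, name):
    # mirror of the linearize check: every parent common, or target is a common op
    if all(p in cnode for p in parents) or is_cop(target):
        cnode.append(name)

cnode = ["attention_mask"]                                     # user-annotated common node
propagate(["attention_mask"], "getitem", cnode, "mask_slice")  # common op branch
propagate(["mask_slice"], "mul", cnode, "extended_mask")       # all-parents-common branch
print(cnode)   # ['attention_mask', 'mask_slice', 'extended_mask']
```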