From de1e716dc455c8cad1828d1765529e8e113f5196 Mon Sep 17 00:00:00 2001 From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com> Date: Fri, 26 Aug 2022 10:34:21 +0800 Subject: [PATCH] [fx] Add activation checkpoint solver rotor (#1496) * [fx] fix defining ckpt functions inside forward * [fx] Modify activation checkpoint codegen and add ColoGraphModule * [fx] some modification * some modifications * some modifications * some modifications * some modifications * some code modifications * [automatic_parallel] ckpt solver rotor * [fx] add ckpt_solver_rotor * [fx] modification * code refactor * code refactor --- colossalai/fx/__init__.py | 1 + colossalai/fx/passes/algorithms/__init__.py | 2 + .../fx/passes/algorithms/ckpt_solver_rotor.py | 198 +++++++++++++++ colossalai/fx/passes/algorithms/linearize.py | 89 +++++++ colossalai/fx/passes/algorithms/utils.py | 229 ++++++++++++++++++ .../test_ckpt_torchvision.py | 14 +- 6 files changed, 529 insertions(+), 4 deletions(-) create mode 100644 colossalai/fx/passes/algorithms/ckpt_solver_rotor.py create mode 100644 colossalai/fx/passes/algorithms/linearize.py create mode 100644 colossalai/fx/passes/algorithms/utils.py diff --git a/colossalai/fx/__init__.py b/colossalai/fx/__init__.py index ec6508a30..6513f6d03 100644 --- a/colossalai/fx/__init__.py +++ b/colossalai/fx/__init__.py @@ -1 +1,2 @@ from .tracer import ColoTracer +from .graph_module import ColoGraphModule diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py index bf6f9eb28..465fef432 100644 --- a/colossalai/fx/passes/algorithms/__init__.py +++ b/colossalai/fx/passes/algorithms/__init__.py @@ -1 +1,3 @@ from .ckpt_solver_chen import chen_greedy +from .linearize import linearize +from .ckpt_solver_rotor import solver_rotor diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py new file mode 100644 index 000000000..396cf7b29 --- /dev/null +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -0,0 +1,198 @@ +from typing import List, Set, Tuple, Dict +import torch +from torch.fx import GraphModule, Node +import math +from .linearize import linearize +from .utils import * +from colossalai.fx.profiler import profile_function, profile_module +from colossalai.fx.passes.meta_info_prop import MetaInfoProp + + +# this is the python compute table code from rotor +# https://gitlab.inria.fr/hiepacs/rotor +# paper link: https://hal.inria.fr/hal-02352969 +def _compute_table(chain: Chain, mmax) -> Tuple: + """Returns the optimal table: a tuple containing: + Opt[m][lmin][lmax] with lmin = 0...chain.length + and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax + what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint + (False, j) if the optimal choice is a leaf checkpoint of length j + The computation uses dynamic programming""" + + fw = chain.fweight + [0] ## forward time + bw = chain.bweight ## backward time, not used + cw = chain.cweight + [0] ## size of x (and of y) + cbw = chain.cbweight + [0] ## size of xbar + fwd_tmp = chain.fwd_tmp + [0] + bwd_tmp = chain.bwd_tmp + [0] + + # Build table + opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)] + what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)] + ## Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation + + # Initialize borders of the tables for lmax-lmin = 0 + for m in range(mmax + 1): + for i in range(chain.length + 1): + #lmax-lmin = 0 + limit = max(cw[i + 1] + cbw[i + 1] + fwd_tmp[i], cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) + if m >= limit: ## Equation (1) + opt[m][i][i] = fw[i] + bw[i] + else: + opt[m][i][i] = float("inf") + + # Compute everything + for m in range(mmax + 1): + for d in range(1, chain.length + 1): + for i in range(chain.length + 1 - d): + # for idx in range(i+1, chain.length + 1): + idx = i + d + mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i] + if idx > i + 1: + mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_tmp[j] for j in range(i + 1, idx))) + if m < mmin: + opt[m][i][idx] = float("inf") + else: + leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1]) + for j in range(i + 1, idx + 1) + if m >= cw[j]] + if leaf_checkpoints: + best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) + else: + best_leaf = None + if m >= cbw[i + 1]: + chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx] + else: + chain_checkpoint = float("inf") + if best_leaf and best_leaf[1] <= chain_checkpoint: + opt[m][i][idx] = best_leaf[1] + what[m][i][idx] = (False, best_leaf[0]) + else: + opt[m][i][idx] = chain_checkpoint + what[m][i][idx] = (True,) + return (opt, what) + + +def _rec(chain, lmin, lmax, cmem, opt_table): + """ chain : the class describing the AC graph + lmin : index of the first forward to execute + lmax : upper bound index of the last forward to execute (not included) + cmem : number of available memory slots + Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]""" + if cmem <= 0: + raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem)) + opt, what = opt_table + sequence = Sequence(Function("Persistent", lmax - lmin, cmem)) + if opt[cmem][lmin][lmax] == float("inf"): + raise ValueError("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin, + lmax=lmax, + cmem=cmem)) + if lmin == lmax: + if lmin == chain.length: + sequence.insert(Loss()) + else: + sequence.insert(ForwardEnable(lmin)) + sequence.insert(Backward(lmin)) + return sequence + + if what[cmem][lmin][lmax][0]: + sequence.insert(ForwardEnable(lmin)) + sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweigth[lmin + 1], opt_table)) + sequence.insert(Backward(lmin)) + else: + j = what[cmem][lmin][lmax][1] + sequence.insert(ForwardCheck(lmin)) + for k in range(lmin + 1, j): + sequence.insert(ForwardNograd(k)) + sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweigth[j], opt_table)) + sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table)) + return sequence + + +def _discretize(mem_unit, values): + return [math.ceil(value / mem_unit) for value in values] + + +def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int) -> Chain: + + fwd_time = [] + bwd_time = [] + xbar_sizes = [data.numel() * data.element_size()] + x_sizes = [data.numel() * data.element_size()] + + # currently we can't get the temp memory needed in fwd and bwd + tmp_fwd = [0] * len(node_dict) + tmp_bwd = [0] * (len(node_dict) + 1) + + for key in node_dict.keys(): + fwd_time.append(0) + bwd_time.append(0) + xbar_sizes.append(0) + x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel * + torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size()) + for node in node_dict[key]: + fwd_time[-1] += node.__flops__ + + # currently we haven't patched the backward flops count + bwd_time[-1] += node.__flops__ * 2 + + xbar_sizes[-1] += node.__activation__ + + xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1]) + + bwd_time.append(0) + + fwd_time = _discretize(mem_unit, fwd_time) + bwd_time = _discretize(mem_unit, bwd_time) + xbar_sizes = _discretize(mem_unit, xbar_sizes) + x_sizes = _discretize(mem_unit, x_sizes) + tmp_fwd = _discretize(mem_unit, tmp_fwd) + tmp_bwd = _discretize(mem_unit, tmp_bwd) + + return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd) + + +def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> GraphModule: + op_list = sequence.list_operations() + loss_op = [op for op in op_list if isinstance(op, Loss)][0] + op_list = op_list[:op_list.index(loss_op)] + ckpt_idx = 0 + in_ckpt = False + ckpt_region = [] + for idx, op in enumerate(op_list, 1): + if in_ckpt: + if isinstance(op, ForwardNograd): + ckpt_region.append(idx) + + elif isinstance(op, ForwardEnable): + in_ckpt = False + for idx in ckpt_region: + for node in node_dict[idx]: + setattr(node, "activation_checkpoint", ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for idx in ckpt_region: + for node in node_dict[idx]: + setattr(node, "activation_checkpoint", ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [idx] + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + ckpt_region.append(idx) + + +def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule: + node_dict = linearize(gm) + mem_unit = mem_limit // mem_slots + MetaInfoProp(gm).run(data) + chain: Chain = _construct_chain(node_dict, data, mem_unit) + opt_table = _compute_table(chain, mem_slots) + sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) + _annotate_from_sequence(sequence, node_dict) + return gm diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py new file mode 100644 index 000000000..19d84a046 --- /dev/null +++ b/colossalai/fx/passes/algorithms/linearize.py @@ -0,0 +1,89 @@ +from typing import OrderedDict +from torch.fx import GraphModule +from collections import OrderedDict +import pdb + + +def linearize(gm: GraphModule) -> dict: + status_dict = {} + node_dict = OrderedDict() + node_idx = 0 + for node in gm.graph.nodes: + last_dict_len = len(status_dict) + # remove node from users list in status_dict + for item in status_dict.values(): + if node in item: + item.remove(node) + + # pop node from status_dict if it is fully used + for key in list(status_dict): + if len(status_dict[key]) == 0: + status_dict.pop(key) + + # first node in graph, it should be in n0-n1 type, + # where n0 contains only input op, i.e. placeholder + if last_dict_len == 0: + node_dict[node_idx] = [node] + status_dict[node.name] = list(node.users) + node_idx += 1 + node_dict[node_idx] = [] + + continue + + # boundary case + if len(status_dict) == 0: + # current node region end point = next node region start point + # i.e. n1-n2-n3-... type node, each node contains only one op + if last_dict_len == 1: + if len(node_dict[node_idx]) > 0: + node_idx += 1 + node_dict[node_idx] = [] + node_dict[node_idx].append(node) + status_dict[node.name] = list(node.users) + + continue + + # n1-n2-n3, if n1 has multiple ops, the last op in n1 will be + # the one who is able to clean all others in status_dict + # and as the last_dict_len > 1, there are multiple ops are used + # by this node, we view it as the end of one node and start a new node + else: + + node_dict[node_idx].append(node) + status_dict[node.name] = list(node.users) + node_idx += 1 + node_dict[node_idx] = [] + + continue + + else: + # currently I will use bigger node structure + # if the following region is activated, the node will be smaller + ################################################# + # if last_dict_len == 1: + # if len(node_dict[node_idx]) > 0: + # node_idx += 1 + # node_dict[node_idx] = [node] + # status_dict[node.name] = list(node.users) + # + # continue + ################################################# + + # in-node case, as the current node can not clean status_dict + # we view it as in-node status, the node will be appended to the + # current node_idx + node_dict[node_idx].append(node) + status_dict[node.name] = list(node.users) + + continue + + # If the output node use multiple nodes, there might be an + # empty node after the output node + if len(node_dict[node_idx]) == 0: + node_dict.pop[node_idx] + node_idx -= 1 + + # pop the last two nodes + node_dict.pop(0) + node_dict.pop(node_idx) + return node_dict diff --git a/colossalai/fx/passes/algorithms/utils.py b/colossalai/fx/passes/algorithms/utils.py new file mode 100644 index 000000000..88efe0a0c --- /dev/null +++ b/colossalai/fx/passes/algorithms/utils.py @@ -0,0 +1,229 @@ +class Chain: + + def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True): + self.fweight = fw + self.bweight = bw + self.cweight = cw + self.cbweight = cbw + self.fwd_tmp = ftmp + self.bwd_tmp = btmp + self.length = len(fw) + if check and not self.check_lengths(): + raise AttributeError("In Chain, input lists do not have consistent lengths") + + def check_lengths(self): + return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1) + and (len(self.cweight) == self.length + 1) and (len(self.fwd_tmp) == self.length) + and (len(self.bwd_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1)) + + def __repr__(self): + chain_list = [] + for i in range(self.length): + chain_list.append( + (self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_tmp[i], self.bwd_tmp[i])) + i = self.length + chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_tmp[i])) + return chain_list.__repr__() + + +class Operation: + + def shift(self, value): + if type(self.index) is tuple: + self.index = tuple(x + value for x in self.index) + else: + self.index += value + + +class Forward(Operation): + + def __init__(self, index): + self.index = index + self.name = "F" + + def __repr__(self): + return "{n}_{i}".format(n=self.name, i=self.index) + + def cost(self, chain): + if chain is not None: + return chain.fweigth[self.index] + else: + return 1 + + +class ForwardEnable(Forward): + + def __init__(self, index): + super().__init__(index) + self.name = "Fe" + + +class ForwardNograd(Forward): + + def __init__(self, index): + super().__init__(index) + self.name = "Fn" + + +class ForwardCheck(Forward): + + def __init__(self, index): + super().__init__(index) + self.name = "CF" + + +class Forwards(Operation): + + def __init__(self, start, end): + self.index = (start, end) + + def __repr__(self): + return "F_{i}->{j}".format(i=self.index[0], j=self.index[1]) + + def cost(self, chain): + if chain is not None: + return sum(chain.fweigth[self.index[0]:self.index[1] + 1]) + else: + return (self.index[1] - self.index[0] + 1) + + +def isForward(op): + return type(op) is Forward or type(op) is Forwards + + +class Backward(Operation): + + def __init__(self, index): + self.index = index + + def __repr__(self): + return "B_{i}".format(i=self.index) + + def cost(self, chain): + if chain is not None: + return chain.bweigth[self.index] + else: + return 1 + + +class Loss(Operation): + + def __init__(self): + pass + + def __repr__(self): + return "L" + + def cost(self, chain): + return 0 + + +class MemoryAccess(Operation): + + def __init__(self, index): + self.index = index + + def __repr__(self): + return "{n}_{i}".format(n=self.name, i=self.index) + + def cost(self, chain): + return 0 + + +class WriteMemory(MemoryAccess): + + def __init__(self, index): + super().__init__(index) + self.name = "WM" + + +class ReadMemory(MemoryAccess): + + def __init__(self, index): + super().__init__(index) + self.name = "RM" + + +class DiscardMemory(MemoryAccess): + + def __init__(self, index): + super().__init__(index) + self.name = "DM" + + +class Function: + + def __init__(self, name, *args): + self.name = name + self.args = args + self.str_args = ','.join(str(v) for v in self.args) + + def __repr__(self): + return "{n}({args})".format(n=self.name, args=self.str_args) + + +class Sequence: + + def __init__(self, function): + self.sequence = [] #List of Operation and Sequence + self.function = function #Description the function (name and parameters) + + def __repr__(self): + return repr(self.list_operations()) + + def list_operations(self): + op_list = [] + for x in self.sequence: + if isinstance(x, Operation): + op_list.append(x) + else: + assert isinstance(x, Sequence) + op_list += x.list_operations() + return op_list + + def insert(self, operation): + self.sequence.append(operation) + + def remove(self, operation_index): + del self.sequence[operation_index] + + def insert_sequence(self, sequence): + self.sequence.append(sequence) + + def shift(self, value): + for x in self.sequence: + x.shift(value) + return self + + def remove_useless_write(self): + if self.sequence: + if isinstance(self.sequence[0], WriteMemory): + self.remove(0) + return self + + def get_makespan(self, chain): + return sum(op.cost(chain) for op in self.list_operations()) + + def without_suffix(self): + ops = self.list_operations() + end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0] + try: + last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable) + except ValueError: + last_idx = -1 + if last_idx == end_of_first_phase - 1: + return (self, None) + chain_length = ops[end_of_first_phase - + 1].index ## Some assumption here about the sequence (finishes with Forward_L + start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice + result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain)) + for i in range(last_idx + 1): + result.insert(ops[i]) + result.insert(Loss()) + for i in range(chain_length, start_of_fwd_enable_chain - 1, -1): + position = end_of_first_phase + 1 + (chain_length - i) + assert type(ops[position]) is Backward + assert ops[position].index == i + for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)): + result.insert(ops[i]) + return (result, start_of_fwd_enable_chain) diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 31e54db36..1d6352d07 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -9,7 +9,7 @@ import colossalai from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.fx.passes.algorithms import chen_greedy +from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.utils import free_port from colossalai.core import global_context as gpc import pytest @@ -22,7 +22,7 @@ except: from colossalai.fx.codegen import python_code_with_activation_checkpoint with_codegen = False -SOLVERS = [chen_greedy] +SOLVERS = [chen_greedy, solver_rotor] def _is_activation_checkpoint_available(gm: GraphModule): @@ -77,7 +77,10 @@ def _run_ckpt_solver(rank): MetaInfoProp(gm).run(data) codegen = ActivationCheckpointCodeGen() gm.graph.set_codegen(codegen) - gm = solver(gm) + if solver == solver_rotor: + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) + else: + gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" @@ -106,7 +109,10 @@ def _run_ckpt_solver_torch11(rank): gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) MetaInfoProp(gm).run(data) gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) - gm = solver(gm) + if solver == solver_rotor: + gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500) + else: + gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"