diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py index 465fef432..9ccf135d0 100644 --- a/colossalai/fx/passes/algorithms/__init__.py +++ b/colossalai/fx/passes/algorithms/__init__.py @@ -1,3 +1,4 @@ from .ckpt_solver_chen import chen_greedy from .linearize import linearize from .ckpt_solver_rotor import solver_rotor +from .ckpt_solver_pofo import solver_pofo diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py new file mode 100644 index 000000000..03cf30714 --- /dev/null +++ b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py @@ -0,0 +1,404 @@ +from typing import List, Tuple +import copy +import torch +from torch.fx import GraphModule, Node +from colossalai.fx.graph_module import ColoGraphModule +from colossalai.fx.profiler import parameter_size +import math +from .linearize import linearize +from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions +from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec + +INF = float("inf") + + +def _normalize_flops(chain: Chain, flops) -> Chain: + """ + Normalize flops + """ + for i in range(chain.length): + chain.fweight[i] /= flops + chain.bweight[i] /= flops + + return chain + + +class PofoTable: + """PofoTable + The PofoTable contains the necessary components to store intermediate results + of dynamic programming and the operations alone the way. + """ + + def __init__(self, chain_length: int, mem_slots: int): + """Init pofo table + The pofo table contains two tables, opt and what, indicating values and + operations. + + Args: + chain_length (int): chain length + mem_slots (int): number of memory slots + """ + + self.length = chain_length + self.mem_slots = mem_slots + + # initializing tables + # the first bool indicates whether the input has bar + # opt table is for value, opt[True/False][i][A][(df, db)] = OCx(i, A, df, db) + # what table is for decision, what[True/False][i][A][(df, db)] = (is_enable, is_offload, index) + # where is_enable indicates whether we enable the gradient, is_offload indicates whether we + # offload the input, index indicates the end of F_\empty sequence if is_enable = False + self.opt = { + False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)], + True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)] + } + self.what = { + False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)], + True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)] + } + + def _get_value(self, state, table, default): + i, act_size, df, db, input_has_bar = state + if act_size + df > self.mem_slots or act_size + db > self.mem_slots: + return default + + try: + return table[input_has_bar][i][act_size][(df, db)] + except KeyError: + print(f"state not found {state}") + + def get_opt(self, state): + return self._get_value(state, self.opt, INF) + + def get_what(self, state): + return self._get_value(state, self.what, INF) + + def set_value(self, state, opt, what): + i, act_size, df, db, input_has_bar = state + self.opt[input_has_bar][i][act_size][(df, db)] = opt + self.what[input_has_bar][i][act_size][(df, db)] = what + + +class PofoSolver: + """PofoSolver that executes algorithm mentioned in https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html + The new pofo solver is based on paper Efficient Combination of Rematerialization and Offloading for Training DNNs + and it's code given in the supplemental. Currently we doesn't use the whole set up in the original paper and reuse + rotor solver for the backward sequence as suggested in supplemental. The solver now is able to find strategy with offload. + """ + + def __init__(self, chain: Chain, max_memory: int, bandwidth, mem_slots: int) -> None: + self.chain = chain + self.length = chain.length + self.max_memory = max_memory + self.mem_slots = mem_slots + self.mem_unit = max_memory / mem_slots + self.bandwidth = bandwidth + + self.disc_chain = copy.deepcopy(self.chain) + + self.rotor_table = _compute_table(self.disc_chain, mem_slots) + self._compute_pofo_table() + + def _discretize(self, *values) -> Tuple: + return tuple(math.ceil(value / self.mem_unit) for value in values) + + def _undiscretize(self, *discrete_values) -> Tuple: + if len(discrete_values) == 1: + return discrete_values[0] * self.mem_unit + else: + return tuple(d * self.mem_unit for d in discrete_values) + + def _mmax_all(self, idx: int): + """ + Calculate the maximum memory usage of Fi_all + """ + + return self.chain.cbweight[idx + 1] + self.chain.fwd_mem_tmp[idx] + + def _mmax_b(self, idx: int): + """ + Calculate the maximum memory usage of Bi + """ + + return self.chain.cbweight[idx + + 1] + self.chain.cweight[idx + + 1] + self.chain.cweight[idx] + self.chain.bwd_mem_tmp[idx] + + def _mmax_ng(self, i: int, j: int): + """ + Calculate the maximum memory usage of CF_i, F_i+1\empty, ... F_j\empty + """ + + res = self.chain.cweight[j + 1] + self.chain.fwd_mem_tmp[j] + if j > i: + res += self.chain.cweight[j] + return res + + def _rotor_estimated_bwd(self, i, j, m, delta): + compute = self.rotor_table[0][math.floor((m - self.chain.cweight[i]) / self.mem_unit)][i][j] + comm = delta / self.bandwidth + return (max(compute, comm) + compute + comm) / 2 + + def _rotor_estimated_bwd_sequence(self, i, j, m, delta): + return _rec(self.disc_chain, i, j, math.floor(m - self.chain.cweight[i] / self.mem_unit), self.rotor_table) + + def _common_values_enable(self, state: Tuple): + + idx, act_size, df, db, input_has_bar = state + input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx] + mf = act_size + df + input_size + mb = act_size + db + input_size + mem_avail = self.max_memory - act_size - input_size + f_usage = self._mmax_all(idx) + b_usage = self._mmax_b(idx) + + # infeasible + if f_usage > mem_avail or b_usage > mem_avail: + return None + + # calculate idle time + eps_f_beta = max(0, f_usage - self.max_memory + mf) + eps_b_beta = max(0, b_usage - self.max_memory + mb) + idle_time = (eps_f_beta + eps_b_beta) / self.bandwidth + + # calculate offload and prefetch data + offload_data = self.chain.fweight[idx] * self.bandwidth + eps_f_beta + prefetch_data = self.chain.bweight[idx] * self.bandwidth + eps_b_beta + + # total_time + total_time = self.chain.fweight[idx] + self.chain.bweight[idx] + idle_time + + return (offload_data, prefetch_data, total_time, idle_time) + + def _common_values_nograd(self, state: Tuple, j: int, iterative: bool = False): + + i, act_size, df, db, input_has_bar = state + + # compute new epsilon_tmp and sum_fwds + if iterative: + self.epsilon_tmp = max(self.epsilon_tmp, self._mmax_ng(i, j) - self.bandwidth * self.sum_fwds) + self.sum_fwds += self.chain.fweight[j] + else: + self.epsilon_tmp = max( + self._mmax_ng(i, k) - self.bandwidth * sum(self.chain.fweight[i:k]) for k in range(i, j + 1)) + self.sum_fwds = sum(self.chain.fweight[i:j + 1]) + + input_size = self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i] + mf = act_size + df + input_size + mem_avail = self.max_memory - act_size - input_size + + # if infeasible + if max(self._mmax_ng(i, k) for k in range(i, self.length)) > mem_avail: + return None + + eps_f_beta = max(0, self.epsilon_tmp - self.max_memory + mf) + offload_data = self.sum_fwds * self.bandwidth + eps_f_beta + + # TODO: Implement the precise backward recompute sequence mentioned in the paper + # currently we will use an approximate way to get the backward time + time_backward = self._rotor_estimated_bwd(i, j, mem_avail, db) + + prefetch_data = time_backward * self.bandwidth + idle_time = eps_f_beta / self.bandwidth + total_time = self.sum_fwds + idle_time + time_backward + + return (offload_data, prefetch_data, total_time, idle_time) + + def _new_values(self, state: Tuple, do_offload: bool, common_values: Tuple) -> Tuple: + """Generate new values for next state + + Args: + state (Tuple): undiscretized states + do_offload (bool): bool type indicates whether we need to do offload + common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time) + + Returns: + Tuple: (new_act_size, new_df, new_db) + """ + idx, act_size, df, db, input_has_bar = state + offload_data, prefetch_data, *_ = common_values + input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx] + if do_offload: + new_act_size = act_size + new_df = max(0, df + input_size - offload_data) + new_db = max(0, db - prefetch_data) + input_size + else: + new_act_size = act_size + input_size + new_df = max(0, df - offload_data) + new_db = max(0, db - prefetch_data) + + return (new_act_size, new_df, new_db) + + def _compute_pofo_table(self): + self.table = PofoTable(self.length, self.mem_slots) + + # initializing the loss + for act_size in range(self.mem_slots + 1): + for df in range(self.mem_slots - act_size + 1): + for db in range(self.mem_slots - act_size + 1): + # undiscretize for idle time calculation + origin_values = self._undiscretize(act_size, df, db) + + for input_has_bar in (False, True): + disc_state = (self.length, act_size, df, db, input_has_bar) + state = (self.length, *origin_values, input_has_bar) + common_values = self._common_values_enable(state) + + # if no feasible choice + if common_values is None: + self.table.set_value(disc_state, INF, None) + continue + + # if there is feasible choice + new_act_size, new_df, new_db = self._new_values(state, False, common_values) + eps_g = (new_df + new_db) / self.bandwidth + total_time = common_values[2] + eps_g + self.table.set_value(disc_state, total_time, (True, False)) + + # main loop + for i in reversed(range(self.length)): + for act_size in range(self.mem_slots + 1): + for df in range(self.mem_slots - act_size + 1): + for db in range(self.mem_slots - act_size + 1): + # undiscretize for idle time calculation + origin_values = self._undiscretize(act_size, df, db) + + for input_has_bar in (False, True): + best_result = INF + best_choice = None + disc_state = (i, act_size, df, db, input_has_bar) + state = (i, *origin_values, input_has_bar) + + # case 1: start with F_all + vals_enable = self._common_values_enable(state) + if vals_enable is not None: + for do_offload in (True, False): + new_state = self._new_values(state, do_offload, vals_enable) + new_state = (i + 1, *self._discretize(*new_state), True) + total_time = vals_enable[2] + results_all = self.table.get_opt(new_state) + total_time + if results_all < best_result: + best_result = results_all + best_choice = (True, do_offload) + + # case 2: start with F_ck + self.sum_fwds = 0 + self.epsilon_tmp = 0 + for j in range(i, self.length): + vals_nograd = self._common_values_nograd(state, j, True) + + # if infeasible + if vals_nograd is None: + continue + + for do_offload in (True, False): + new_state = self._new_values(state, do_offload, vals_nograd) + new_state = (j + 1, *self._discretize(*new_state), False) + total_time = vals_nograd[2] + result_nograd = total_time + self.table.get_opt(new_state) + if result_nograd < best_result: + best_result = result_nograd + best_choice = (False, do_offload, j) + + self.table.set_value(disc_state, best_result, best_choice) + + def pofo_rec(self, disc_state): + i, act_size, df, db, input_has_bar = disc_state + result = Sequence(Function("pofo", *disc_state)) + what = self.table.get_what(disc_state) + state = self._undiscretize(act_size, df, db) + state = (i, *state, input_has_bar) + i, act_size, df, db, input_has_bar = state + + if what is None: + return None + + # if loss + if i == self.length: + result.insert(Loss()) + return result + + if what[0]: + do_offload = what[1] + values = self._common_values_enable(state) + new_state = self._discretize(*self._new_values(state, do_offload, values)) + new_state = (i + 1, *new_state, True) + if do_offload: + result.insert(Offload(i, input_has_bar)) + result.insert(ForwardEnable(i)) + result.insert_sequence(self.pofo_rec(new_state)) + if do_offload: + result.insert(Prefetch(i, input_has_bar)) + result.insert(Backward(i)) + + else: + _, do_offload, j = what + values = self._common_values_nograd(state, j) + new_state = self._discretize(*self._new_values(state, do_offload, values)) + new_state = (j + 1, *new_state, False) + if do_offload: + result.insert(Offload(i, input_has_bar)) + result.insert(ForwardCheck(i)) + for k in range(i + 1, j + 1): + result.insert(ForwardNograd(k)) + result.insert_sequence(self.pofo_rec(new_state)) + if do_offload: + result.insert(Prefetch(i, input_has_bar)) + m = self.max_memory - act_size - (self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i]) + + #TODO: Implement the precise backward recompute sequence mentioned in the paper + result.insert_sequence(self._rotor_estimated_bwd_sequence(i, j, m, db)) + + return result + + +def solver_pofo(gm: ColoGraphModule, + data, + bandwidth, + flops, + mem_limit: int, + mem_slots: int = 50, + cnode: List[str] = None, + eps: float = 0.0) -> ColoGraphModule: + """Solver that combine offload and activation checkpoint + Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html + + Args: + gm (ColoGraphModule): ColoGraphModule derived from tracer + data: input of the model + bandwidth: offload bandwidth, unit Byte/s + flops: FLOPS of device, unit FLOPs/s + mem_limit (int): memory limit, unit Byte + mem_slots (int, optional): number of memory slots. Defaults to 500. + cnode (List[str], optional): common node for linearize. Defaults to None. + eps (float, optional): epsilon for memory decay. Defaults to 0.02. + + Returns: + ColoGraphModule: annotated graph module + """ + + node_list = linearize(gm, cnode) + mem_limit -= parameter_size(gm) + + # prepare data + MetaInfoProp(gm).run(data) + chain: Chain = _construct_chain(node_list, data) + chain = _normalize_flops(chain, flops) + # currently we view loss as an op without expense + chain.cbweight.append(0) + chain.cweight.append(0) + chain.fwd_mem_tmp.append(0) + chain.bwd_mem_tmp.append(0) + chain.fweight.append(0) + chain.bweight.append(0) + + solver = PofoSolver(chain, mem_limit, bandwidth, mem_slots) + first_state = (0, 0, 0, 0, False) + sequence = solver.pofo_rec(first_state) + if sequence == None: + print(f"Can not solve strategy with {mem_limit / 1024**2} MB memory!") + + setattr(gm, "__sequence__", sequence) + return gm diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py index d0928b405..f9991c407 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py @@ -4,7 +4,7 @@ from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.profiler import activation_size, parameter_size import math from .linearize import linearize -from .utils import * +from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions @@ -110,10 +110,6 @@ def _rec(chain: Chain, lmin, lmax, cmem, opt_table): return sequence -def _discretize(mem_unit, values): - return [math.ceil(value / mem_unit) for value in values] - - def _fwd_xbar(node: List[Node]) -> int: """Get the forward xbar of a node @@ -204,7 +200,7 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int: return bwd_mem_tmp -def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain: +def _construct_chain(node_list: List[List[Node]], input) -> Chain: fwd_time = [] bwd_time = [] @@ -226,11 +222,6 @@ def _construct_chain(node_list: List[List[Node]], input, mem_unit: int) -> Chain # currently we view loss backward temp as zero tmp_bwd.append(0) - xbar_sizes = _discretize(mem_unit, xbar_sizes) - x_sizes = _discretize(mem_unit, x_sizes) - tmp_fwd = _discretize(mem_unit, tmp_fwd) - tmp_bwd = _discretize(mem_unit, tmp_bwd) - return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd) @@ -345,7 +336,9 @@ def solver_rotor(gm: ColoGraphModule, mem_limit -= parameter_size(gm) mem_unit = mem_limit * (1.0 - eps) // mem_slots MetaInfoProp(gm).run(data) - chain: Chain = _construct_chain(node_list, data, mem_unit) + + chain: Chain = _construct_chain(node_list, data) + chain._discretize(mem_unit) opt_table = _compute_table(chain, mem_slots) sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table) _annotate_from_sequence(sequence, node_list) diff --git a/colossalai/fx/passes/algorithms/utils.py b/colossalai/fx/passes/algorithms/operation.py similarity index 86% rename from colossalai/fx/passes/algorithms/utils.py rename to colossalai/fx/passes/algorithms/operation.py index 78fb0c363..cedbbe85e 100644 --- a/colossalai/fx/passes/algorithms/utils.py +++ b/colossalai/fx/passes/algorithms/operation.py @@ -1,3 +1,10 @@ +import math + + +def _discretize(mem_unit, values): + return [math.ceil(value / mem_unit) for value in values] + + class Chain: def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True): @@ -25,6 +32,12 @@ class Chain: chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i])) return chain_list.__repr__() + def _discretize(self, mem_unit): + self.cweight = _discretize(mem_unit, self.cweight) + self.cbweight = _discretize(mem_unit, self.cbweight) + self.fwd_mem_tmp = _discretize(mem_unit, self.fwd_mem_tmp) + self.bwd_mem_tmp = _discretize(mem_unit, self.bwd_mem_tmp) + class Operation: @@ -35,6 +48,32 @@ class Operation: self.index += value +class Offload(Operation): + + def __init__(self, index, has_bar=False) -> None: + super().__init__() + self.index = index + self.name = "Off" + if has_bar: + self.name += "wBar" + + def __repr__(self): + return f"{self.name}_{self.index}" + + +class Prefetch(Operation): + + def __init__(self, index, has_bar=False) -> None: + super().__init__() + self.index = index + self.name = "Pre" + if has_bar: + self.name += "wBar" + + def __repr__(self): + return f"{self.name}_{self.index}" + + class Forward(Operation): def __init__(self, index): diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py index 4b6f91a4d..8f5d6abe5 100644 --- a/tests/test_fx/test_ckpt_solvers/test_linearize.py +++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py @@ -3,7 +3,7 @@ import torchvision.models as tm from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.algorithms import solver_rotor, linearize -from colossalai.fx.passes.algorithms.utils import Loss, ForwardCheck, ForwardEnable, ForwardNograd +from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd import pytest try: