From f9217336214a6325ec98e3532535fbb876e05f41 Mon Sep 17 00:00:00 2001 From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com> Date: Sat, 24 Sep 2022 01:52:57 +0800 Subject: [PATCH] [autoparallel] Add pofo sequence annotation (#1637) * [autoparallel] annotate pofo sequence * [autoparallel] remove unused print * [autoparallel] fix some code --- .../codegen/activation_checkpoint_codegen.py | 4 +- .../fx/passes/algorithms/ckpt_solver_pofo.py | 129 +++++++++++++++++- colossalai/fx/passes/algorithms/operation.py | 6 +- 3 files changed, 133 insertions(+), 6 deletions(-) diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 4da4315d4..5bb04a68e 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -145,7 +145,7 @@ def _find_ckpt_regions(nodes: List[Node]): def _find_offload_regions(nodes: List[Node]): """This function is to find the offload regions In pofo algorithm, during annotation, we will annotate the offload region with the - tuple in the form of (idx, offload_input, offload_bar). idx indicates the offload + list in the form of [idx, offload_input, offload_bar]. idx indicates the offload region's index, offload_input is a bool type indicates whether we need to offload the input, offload_bar is a bool type indicates whether we need to offload all the intermediate x_bars of this region. @@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]): current_region = None for idx, node in enumerate(nodes): - if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', False), tuple): + if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), list): act_offload_label = node.activation_offload if current_region == None: diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py index 03cf30714..b895eb038 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py @@ -97,6 +97,7 @@ class PofoSolver: self.bandwidth = bandwidth self.disc_chain = copy.deepcopy(self.chain) + self.disc_chain._discretize(self.mem_unit) self.rotor_table = _compute_table(self.disc_chain, mem_slots) self._compute_pofo_table() @@ -142,7 +143,7 @@ class PofoSolver: return (max(compute, comm) + compute + comm) / 2 def _rotor_estimated_bwd_sequence(self, i, j, m, delta): - return _rec(self.disc_chain, i, j, math.floor(m - self.chain.cweight[i] / self.mem_unit), self.rotor_table) + return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table) def _common_values_enable(self, state: Tuple): @@ -354,6 +355,129 @@ class PofoSolver: return result +def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]): + op_list = sequence.list_operations() + loss_op = next(op for op in op_list if isinstance(op, Loss)) + fwd_list = op_list[:op_list.index(loss_op)] + bwd_list = op_list[op_list.index(loss_op) + 1:] + ckpt_idx = 0 + in_ckpt = False + ckpt_region = [] + + # forward annotation + for op in fwd_list: + if in_ckpt: + if isinstance(op, ForwardNograd): + ckpt_region.append(op.index) + + elif isinstance(op, ForwardEnable): + in_ckpt = False + for node_idx in ckpt_region: + for n in node_list[node_idx]: + setattr(n, "activation_checkpoint", [ckpt_idx]) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + setattr(n, "activation_checkpoint", [ckpt_idx]) + + ckpt_idx += 1 + ckpt_region = [op.index] + + else: + if isinstance(op, ForwardCheck): + in_ckpt = True + ckpt_region.append(op.index) + + # annotate the backward if there is any nested activation checkpoint + in_recompute = False + for op in bwd_list: + if in_recompute: + if isinstance(op, ForwardNograd): + ckpt_region.append(op.index) + + elif isinstance(op, ForwardEnable): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.activation_checkpoint.append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [] + + elif isinstance(op, ForwardCheck): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.activation_checkpoint.append(ckpt_idx) + + ckpt_idx += 1 + ckpt_region = [op.index] + + elif isinstance(op, Backward): + for node_idx in ckpt_region: + for n in node_list[node_idx]: + n.activation_checkpoint.append(ckpt_idx) + + in_recompute = False + + else: + if not isinstance(op, Backward): + in_recompute = True + ckpt_idx = 0 + ckpt_region = [] + if isinstance(op, ForwardCheck): + ckpt_region.append(op.index) + + # postprocess, make sure every activation checkpoint label in the + # same activation checkpoint region (level = 0) has the same length + op_list = [] + for node in node_list: + op_list += node + ckpt_regions = _find_nested_ckpt_regions(op_list) + for (start_idx, end_idx) in ckpt_regions: + nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1)) + for idx in range(start_idx, end_idx + 1): + op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint)) + + # annotate the offload + offload_idx = 0 + for idx, op in enumerate(fwd_list): + if isinstance(op, Offload): + # corner case: offload input + if op.index == 0: + if isinstance(fwd_list[idx + 1], ForwardCheck): + for n in node_list[op.index]: + setattr(n, "activation_offload", True) + else: + for n in node_list[op.index]: + setattr(n, "activation_offload", (offload_idx, True, False)) + offload_idx += 1 + + else: + if op.has_bar: + # annotate previous node + if hasattr(node_list[op.index - 1][0], "activation_offload"): + for n in node_list[op.index - 1]: + n.activation_offload[-1] = True + else: + for n in node_list[op.index - 1]: + setattr(n, "activation_offload", [offload_idx, False, True]) + + offload_idx += 1 + + # annotate this node + if isinstance(fwd_list[idx + 1], ForwardCheck): + for n in node_list[op.index]: + setattr(n, "activation_offload", True) + else: + for n in node_list[op.index]: + setattr(n, "activation_offload", [offload_idx, True, False]) + + offload_idx += 1 + + def solver_pofo(gm: ColoGraphModule, data, bandwidth, @@ -398,7 +522,8 @@ def solver_pofo(gm: ColoGraphModule, first_state = (0, 0, 0, 0, False) sequence = solver.pofo_rec(first_state) if sequence == None: - print(f"Can not solve strategy with {mem_limit / 1024**2} MB memory!") + raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory") + _annotate_from_pofo_sequence(sequence, node_list) setattr(gm, "__sequence__", sequence) return gm diff --git a/colossalai/fx/passes/algorithms/operation.py b/colossalai/fx/passes/algorithms/operation.py index cedbbe85e..8bfa3452b 100644 --- a/colossalai/fx/passes/algorithms/operation.py +++ b/colossalai/fx/passes/algorithms/operation.py @@ -54,7 +54,8 @@ class Offload(Operation): super().__init__() self.index = index self.name = "Off" - if has_bar: + self.has_bar = has_bar + if self.has_bar: self.name += "wBar" def __repr__(self): @@ -67,7 +68,8 @@ class Prefetch(Operation): super().__init__() self.index = index self.name = "Pre" - if has_bar: + self.has_bar = has_bar + if self.has_bar: self.name += "wBar" def __repr__(self):