From 56159049e869744bbfdf01a08fa6a0557a2a236d Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Fri, 2 Sep 2022 10:24:41 +0800
Subject: [PATCH] [fx] Modify solver linearize and add corresponding test (#1531)

* [fx] modify solver linearize and add test

* [fx] add torch11 test of linearize but skip it

* [fx] remove some unused imports
---
 .../codegen/activation_checkpoint_codegen.py  |   3 +-
 .../fx/passes/algorithms/ckpt_solver_rotor.py |  40 +++---
 colossalai/fx/passes/algorithms/linearize.py  | 109 +++++----
 .../test_ckpt_solvers/test_linearize.py       | 128 ++++++++++++++++++
 4 files changed, 181 insertions(+), 99 deletions(-)
 create mode 100644 tests/test_fx/test_ckpt_solvers/test_linearize.py

diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index def67c60f..391d64405 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -167,8 +167,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func
         use_reentrant = True
         non_leaf_input = 0
         for var in input_vars[label]:
-            input_node = [item for item in node_list if item.name == var]
-            input_node = input_node[0]
+            input_node = next(item for item in node_list if item.name == var)
             if input_node.op != "placeholder":
                 non_leaf_input = 1
                 for user in input_node.users:
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
index ce209b674..eeb43f3a7 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
@@ -114,7 +114,7 @@ def _discretize(mem_unit, values):
     return [math.ceil(value / mem_unit) for value in values]
 
 
-def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: int) -> Chain:
+def _construct_chain(node_list: List[List[Node]], data: torch.Tensor, mem_unit: int) -> Chain:
 
     fwd_time = []
     bwd_time = []
@@ -122,22 +122,22 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
     x_sizes = [data.numel() * data.element_size()]
 
     # currently we can't get the temp memory needed in fwd and bwd
-    tmp_fwd = [0] * len(node_dict)
-    tmp_bwd = [0] * (len(node_dict) + 1)
+    tmp_fwd = [0] * len(node_list)
+    tmp_bwd = [0] * (len(node_list) + 1)
 
-    for key in node_dict.keys():
+    for idx, node in enumerate(node_list):
         fwd_time.append(0)
         bwd_time.append(0)
         xbar_sizes.append(0)
-        x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
-                       torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
+        x_sizes.append(node[-1].meta['tensor_meta'].numel *
+                       torch.tensor([], dtype=node[-1].meta['tensor_meta'].dtype).element_size())
 
-        for node in node_dict[key]:
-            fwd_time[-1] += max(node.__flops__, 1)
+        for n in node:
+            fwd_time[-1] += max(n.__flops__, 1)
 
             # currently we haven't patched the backward flops count
-            bwd_time[-1] += max(node.__flops__ * 2, 2)
+            bwd_time[-1] += max(n.__flops__ * 2, 2)
 
-            xbar_sizes[-1] += node.__activation__
+            xbar_sizes[-1] += n.__activation__
 
         xbar_sizes[-1] = max(xbar_sizes[-1], x_sizes[-1])
 
@@ -151,14 +151,14 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
     return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
 
 
-def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> GraphModule:
+def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]) -> GraphModule:
     op_list = sequence.list_operations()
-    loss_op = [op for op in op_list if isinstance(op, Loss)][0]
+    loss_op = next(op for op in op_list if isinstance(op, Loss))
     op_list = op_list[:op_list.index(loss_op)]
     ckpt_idx = 0
     in_ckpt = False
     ckpt_region = []
-    for idx, op in enumerate(op_list, 1):
+    for idx, op in enumerate(op_list, 0):
         if in_ckpt:
             if isinstance(op, ForwardNograd):
                 ckpt_region.append(idx)
@@ -166,16 +166,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
             elif isinstance(op, ForwardEnable):
                 in_ckpt = False
                 for node_idx in ckpt_region:
-                    for node in node_dict[node_idx]:
-                        setattr(node, "activation_checkpoint", ckpt_idx)
+                    for n in node_list[node_idx]:
+                        setattr(n, "activation_checkpoint", ckpt_idx)
 
                 ckpt_idx += 1
                 ckpt_region = []
 
             elif isinstance(op, ForwardCheck):
                 for node_idx in ckpt_region:
-                    for node in node_dict[node_idx]:
-                        setattr(node, "activation_checkpoint", ckpt_idx)
+                    for n in node_list[node_idx]:
+                        setattr(n, "activation_checkpoint", ckpt_idx)
 
                 ckpt_idx += 1
                 ckpt_region = [idx]
@@ -199,13 +199,13 @@ def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_sl
         ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
     """
 
-    node_dict = linearize(gm)
+    node_list = linearize(gm)
     mem_unit = mem_limit // mem_slots
     MetaInfoProp(gm).run(data)
-    chain: Chain = _construct_chain(node_dict, data, mem_unit)
+    chain: Chain = _construct_chain(node_list, data, mem_unit)
     opt_table = _compute_table(chain, mem_slots)
     sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
-    _annotate_from_sequence(sequence, node_dict)
+    _annotate_from_sequence(sequence, node_list)
 
     # set __sequence__ attribute to GraphModule
     setattr(gm, "__sequence__", sequence)
diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py
index 19d84a046..e6b47a7ba 100644
--- a/colossalai/fx/passes/algorithms/linearize.py
+++ b/colossalai/fx/passes/algorithms/linearize.py
@@ -1,89 +1,44 @@
-from typing import OrderedDict
-from torch.fx import GraphModule
-from collections import OrderedDict
-import pdb
+from typing import List
+from torch.fx import GraphModule, Node
 
 
-def linearize(gm: GraphModule) -> dict:
-    status_dict = {}
-    node_dict = OrderedDict()
-    node_idx = 0
-    for node in gm.graph.nodes:
-        last_dict_len = len(status_dict)
-        # remove node from users list in status_dict
-        for item in status_dict.values():
-            if node in item:
-                item.remove(node)
+def linearize(gm: GraphModule) -> List[List[Node]]:
+    """Linearize the graph into a list of node regions.
 
-        # pop node from status_dict if it is fully used
-        for key in list(status_dict):
-            if len(status_dict[key]) == 0:
-                status_dict.pop(key)
+    Args:
+        gm (GraphModule): GraphModule derived by tracing
 
-        # first node in graph, it should be in n0-n1 type,
-        # where n0 contains only input op, i.e. placeholder
-        if last_dict_len == 0:
-            node_dict[node_idx] = [node]
-            status_dict[node.name] = list(node.users)
-            node_idx += 1
-            node_dict[node_idx] = []
+    Returns:
+        List[List[Node]]: list of lists, where each inner list of Nodes
+        forms one region of the linearized graph.
+    """
 
-            continue
+    def _is_sink() -> bool:
+        """Check whether every dependency tracked so far has been freed.
 
-        # boundary case
-        if len(status_dict) == 0:
-            # current node region end point = next node region start point
-            # i.e. n1-n2-n3-... type node, each node contains only one op
-            if last_dict_len == 1:
-                if len(node_dict[node_idx]) > 0:
-                    node_idx += 1
-                    node_dict[node_idx] = []
-                node_dict[node_idx].append(node)
-                status_dict[node.name] = list(node.users)
+        Returns:
+            bool
+        """
 
-                continue
+        return not sum(deps.values())
 
-            # n1-n2-n3, if n1 has multiple ops, the last op in n1 will be
-            # the one who is able to clean all others in status_dict
-            # and as the last_dict_len > 1, there are multiple ops are used
-            # by this node, we view it as the end of one node and start a new node
-            else:
+    deps = {}
+    linearized_nodes = []
+    region = []
 
-                node_dict[node_idx].append(node)
-                status_dict[node.name] = list(node.users)
-                node_idx += 1
-                node_dict[node_idx] = []
+    for n in gm.graph.nodes:
+        for n_par in n._input_nodes:
+            deps[n_par] -= 1
+        region.append(n)
 
-                continue
+        # if this node frees all pending dependencies in the graph,
+        # the current region is complete and a new one can begin
+        if _is_sink():
+            linearized_nodes.append(region)
+            region = []
 
-        else:
-            # currently I will use bigger node structure
-            # if the following region is activated, the node will be smaller
-            #################################################
-            # if last_dict_len == 1:
-            #     if len(node_dict[node_idx]) > 0:
-            #         node_idx += 1
-            #         node_dict[node_idx] = [node]
-            #         status_dict[node.name] = list(node.users)
-            #
-            #         continue
-            #################################################
+        deps[n] = len(n.users)
 
-            # in-node case, as the current node can not clean status_dict
-            # we view it as in-node status, the node will be appended to the
-            # current node_idx
-            node_dict[node_idx].append(node)
-            status_dict[node.name] = list(node.users)
-
-            continue
-
-    # If the output node use multiple nodes, there might be an
-    # empty node after the output node
-    if len(node_dict[node_idx]) == 0:
-        node_dict.pop[node_idx]
-        node_idx -= 1
-
-    # pop the last two nodes
-    node_dict.pop(0)
-    node_dict.pop(node_idx)
-    return node_dict
+    # strip the first region (placeholders) and the last region (output)
+    linearized_nodes = linearized_nodes[1:-1]
+    return linearized_nodes
diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_fx/test_ckpt_solvers/test_linearize.py
new file mode 100644
index 000000000..36bd87b42
--- /dev/null
+++ b/tests/test_fx/test_ckpt_solvers/test_linearize.py
@@ -0,0 +1,128 @@
+import torch
+import torchvision.models as tm
+from colossalai.fx import ColoTracer
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.algorithms import solver_rotor, linearize
+from colossalai.fx.passes.algorithms.utils import Loss, ForwardCheck, ForwardEnable, ForwardNograd
+import pytest
+
+try:
+    from colossalai.fx.codegen import ActivationCheckpointCodeGen
+    with_codegen = True
+except ImportError:
+    # fall back to older pytorch version
+    from colossalai.fx.codegen import python_code_with_activation_checkpoint
+    with_codegen = False
+
+
+@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
+def test_linearize():
+    MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
+    tracer = ColoTracer()
+    for M, budgets in MODEL_DICT.items():
+        for budget in budgets:
+            model = M()
+            graph = tracer.trace(model)
+            graph.set_codegen(ActivationCheckpointCodeGen())
+            gm = ColoGraphModule(model, graph, model.__class__.__name__)
+            node_list = linearize(gm)
+            gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
+            op_list = gm.__sequence__.list_operations()
+            loss_op = next(op for op in op_list if isinstance(op, Loss))
+            op_list = op_list[:op_list.index(loss_op)]
+            in_ckpt = False
+            ckpt_idx = 0
+            for idx, op in enumerate(op_list):
+                if in_ckpt:
+                    if isinstance(op, ForwardNograd):
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+                        continue
+
+                    if isinstance(op, ForwardEnable):
+                        for n in node_list[idx]:
+                            assert getattr(n, "activation_checkpoint", None) is None, f"{n} should not be annotated!"
+                        in_ckpt = False
+
+                        ckpt_idx += 1
+                        continue
+
+                    if isinstance(op, ForwardCheck):
+                        ckpt_idx += 1
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+                        continue
+
+                else:
+                    if isinstance(op, ForwardCheck):
+                        in_ckpt = True
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+            del model
+            del gm
+            del node_list
+
+
+@pytest.mark.skip(reason="torch11 meta tensor not implemented")
+@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
+def test_linearize_torch11():
+    MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
+    tracer = ColoTracer()
+    for M, budgets in MODEL_DICT.items():
+        for budget in budgets:
+            model = M()
+            graph = tracer.trace(model)
+            gm = ColoGraphModule(model, graph, model.__class__.__name__)
+            gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
+            node_list = linearize(gm)
+            gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
+            op_list = gm.__sequence__.list_operations()
+            loss_op = next(op for op in op_list if isinstance(op, Loss))
+            op_list = op_list[:op_list.index(loss_op)]
+            in_ckpt = False
+            ckpt_idx = 0
+            for idx, op in enumerate(op_list):
+                if in_ckpt:
+                    if isinstance(op, ForwardNograd):
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+                        continue
+
+                    if isinstance(op, ForwardEnable):
+                        for n in node_list[idx]:
+                            assert getattr(n, "activation_checkpoint", None) is None, f"{n} should not be annotated!"
+                        in_ckpt = False
+
+                        ckpt_idx += 1
+                        continue
+
+                    if isinstance(op, ForwardCheck):
+                        ckpt_idx += 1
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+                        continue
+
+                else:
+                    if isinstance(op, ForwardCheck):
+                        in_ckpt = True
+                        for n in node_list[idx]:
+                            assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
+                            assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
+
+            del model
+            del gm
+            del node_list
+
+
+if __name__ == "__main__":
+    test_linearize()
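
A minimal usage sketch of the reworked pipeline, mirroring the new test above.
It assumes torch >= 1.12 (for ActivationCheckpointCodeGen and meta tensors);
resnet18 and the 2100 MB budget are illustrative values taken from the test,
not requirements.

    import torch
    import torchvision.models as tm
    from colossalai.fx import ColoTracer
    from colossalai.fx.graph_module import ColoGraphModule
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    from colossalai.fx.passes.algorithms import solver_rotor, linearize

    model = tm.resnet18()
    graph = ColoTracer().trace(model)
    graph.set_codegen(ActivationCheckpointCodeGen())
    gm = ColoGraphModule(model, graph, model.__class__.__name__)

    # linearize() now returns List[List[Node]]: each inner list is one region
    # of the linearized graph, with the input/output regions stripped
    node_list = linearize(gm)

    # solver_rotor annotates nodes with `activation_checkpoint` indices and
    # stores the chosen schedule on gm.__sequence__
    gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=2100 * 1024**2)
    print(gm.__sequence__.list_operations())

Because _annotate_from_sequence now enumerates the forward ops from 0 instead
of 1, the i-th operation of the sequence lines up with node_list[i], which is
exactly the correspondence the new test asserts.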