From 0dbd61c29b2d5373bcd9a3bc2be7ce14a1d07ef0 Mon Sep 17 00:00:00 2001 From: Super Daniel <78588128+super-dainiu@users.noreply.github.com> Date: Mon, 15 Aug 2022 19:09:19 +0800 Subject: [PATCH] [fx] fix test and algorithm bugs in activation checkpointing. (#1451) * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] fix test and algorithm bugs in activation checkpointing. * mend [fx] fix test and algorithm bugs in activation checkpointing. * mend [fx] fix test and algorithm bugs in activation checkpointing. * mend [fx] fix test and algorithm bugs in activation checkpointing. * mend [fx] fix test and algorithm bugs in activation checkpointing. * [fx] polish ckpt_test. * [fx] polish ckpt_test. * [fx] polish ckpt_test. --- .../fx/passes/algorithms/ckpt_solver_chen.py | 40 +++++--- .../test_ckpt_torchvision.py | 92 +++++++++++++++---- 2 files changed, 101 insertions(+), 31 deletions(-) diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 046b165a6..8b404e3a6 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -1,4 +1,4 @@ -from typing import Set, Tuple +from typing import List, Set, Tuple import torch from torch.fx import GraphModule import math @@ -6,6 +6,14 @@ import math __all__ = ['chen_greedy', 'chen_sqrtn'] +def _all_potential_ckpt_nodes(gm: GraphModule) -> List: + ckpt_nodes = [] + for n in gm.graph.nodes: + if n.op == 'call_module': + ckpt_nodes.append(n) + return ckpt_nodes + + def chen_greedy(gm: GraphModule) -> GraphModule: """ This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. @@ -31,36 +39,40 @@ def chen_greedy(gm: GraphModule) -> GraphModule: b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) b_opt = math.inf for b in range(b_min, b_max, (b_max - b_min) // num_grids): - ckpt, b_approx = run_chen_greedy(b) + ckpt_intv, b_approx = run_chen_greedy(b) if b_approx < b_opt: b_opt = b_approx - ckpt_opt = ckpt + ckpt_opt = ckpt_intv return ckpt_opt def run_chen_greedy(b: int = 0) -> Tuple[Set, int]: """ This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. """ - ckpt = set() + ckpt_nodes = _all_potential_ckpt_nodes(gm) + ckpt_intv = [] temp = 0 x = 0 y = 0 + prev_idx = 2 for (idx, n) in enumerate(gm.graph.nodes): temp += getattr(n, 'activation_size') y = max(y, temp) - if temp > b: + if temp > b and n in ckpt_nodes: x += getattr(n, 'activation_size') temp = 0 - ckpt.add(idx) - return ckpt, math.floor(math.sqrt(x * y)) + ckpt_intv.append((prev_idx, idx + 1)) + prev_idx = idx + 1 + return ckpt_intv, math.floor(math.sqrt(x * y)) gm.graph.lint() # make sure nodes are in topological order ckpt = grid_search(num_grids=6) - i = 0 - for idx, n in enumerate(gm.graph.nodes): - if idx in ckpt: - setattr(n, 'activation_checkpoint', str(i)) - i += 1 + node_list = list(gm.graph.nodes) + for i, seg in enumerate(ckpt): + for idx in range(*seg): + n = node_list[idx] + if n.op in ['call_module', 'call_method', 'call_function']: + setattr(n, 'activation_checkpoint', str(i)) gm.recompile() return gm @@ -82,7 +94,9 @@ def chen_sqrtn(gm: GraphModule) -> GraphModule: gm.graph.lint() # make sure nodes are in topological order k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints for idx, n in enumerate(gm.graph.nodes): - if (idx + 1) % k == 0: + # We should not add act_ckpt to the placeholder + # The last segment should not be checkpointed + if n.op != 'placeholder' and (idx + 1) // k < k: setattr(n, 'activation_checkpoint', str((idx + 1) // k)) gm.recompile() return gm diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 169b4bcb6..1772c2840 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -1,12 +1,25 @@ -from ctypes import Union -from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn +from typing import Callable +import copy import torch +import torch.multiprocessing as mp import torchvision.models as tm -from colossalai.fx import ColoTracer from torch.fx import GraphModule +import colossalai +from colossalai.fx import ColoTracer from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn +from colossalai.utils import free_port +from colossalai.core import global_context as gpc import pytest +try: + from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True +except: + # fall back to older pytorch version + from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False + SOLVERS = [chen_greedy, chen_sqrtn] @@ -18,37 +31,80 @@ def _is_activation_checkpoint_available(gm: GraphModule): def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): for m_p, gm_p in zip(m.parameters(), gm.parameters()): - if not torch.allclose(m_p, gm_p): + if not torch.allclose(m_p.grad, gm_p.grad): return False return True -def test_ckpt_solver(): +def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], + model_cls: Callable[[], torch.nn.Module]): + criterion = torch.nn.MSELoss() + data = torch.rand(2, 3, 32, 32) + label = torch.rand(2, 5) + loss = criterion(m(data), label) + loss.backward() + loss = criterion(gm(data), label) + loss.backward() + assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + + +def _run_ckpt_solver(rank): + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') MODEL_LIST = [tm.resnet18, tm.densenet121] torch.backends.cudnn.deterministic = True - tracer = ColoTracer() - data = torch.rand(1, 3, 224, 224) - label = torch.rand(1, 1000) + tracer = ColoTracer(trace_act_ckpt=False) + data = torch.rand(2, 3, 32, 32) for solver in SOLVERS: for model_cls in MODEL_LIST: - model = model_cls() - criterion = torch.nn.MSELoss() - graph = tracer.trace(root=model) - gm = GraphModule(model, graph, model.__class__.__name__) + m = model_cls(num_classes=5) + graph = tracer.trace(root=m) + gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__) MetaInfoProp(gm).run(data) + codegen = ActivationCheckpointCodeGen() + gm.graph.set_codegen(codegen) gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - loss = criterion(model(data), label) - loss.backward() - loss = criterion(gm(data), label) - loss.backward() - assert _is_all_gradient_close(model, - gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + check_backward_consistency(m, gm, solver, model_cls) + + +@pytest.mark.skip +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +def test_ckpt_solver(): + mp.spawn(_run_ckpt_solver, nprocs=1) + + +def _run_ckpt_solver_torch11(rank): + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + MODEL_LIST = [tm.resnet18, tm.densenet121] + + torch.backends.cudnn.deterministic = True + + tracer = ColoTracer(trace_act_ckpt=False) + + data = torch.rand(2, 3, 32, 32) + for solver in SOLVERS: + for model_cls in MODEL_LIST: + m = model_cls(num_classes=5) + graph = tracer.trace(root=m) + gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__) + MetaInfoProp(gm).run(data) + gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph) + gm = solver(gm) + assert _is_activation_checkpoint_available( + gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + check_backward_consistency(m, gm, solver, model_cls) + + +@pytest.mark.skip +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +def test_ckpt_solver_torch11(): + mp.spawn(_run_ckpt_solver_torch11, nprocs=1) if __name__ == '__main__': test_ckpt_solver() + test_ckpt_solver_torch11()