[fx] fix test and algorithm bugs in activation checkpointing. (#1451)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] fix test and algorithm bugs in activation checkpointing.

* mend

[fx] fix test and algorithm bugs in activation checkpointing.

* mend

[fx] fix test and algorithm bugs in activation checkpointing.

* mend

[fx] fix test and algorithm bugs in activation checkpointing.

* mend

[fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] polish ckpt_test.

* [fx] polish ckpt_test.
This commit is contained in:
Super Daniel
2022-08-15 19:09:19 +08:00
committed by GitHub
parent b1553fdf96
commit 0dbd61c29b
2 changed files with 101 additions and 31 deletions

View File

@@ -1,4 +1,4 @@
from typing import Set, Tuple
from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule
import math
@@ -6,6 +6,14 @@ import math
__all__ = ['chen_greedy', 'chen_sqrtn']
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
ckpt_nodes = []
for n in gm.graph.nodes:
if n.op == 'call_module':
ckpt_nodes.append(n)
return ckpt_nodes
def chen_greedy(gm: GraphModule) -> GraphModule:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
@@ -31,36 +39,40 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
b_opt = math.inf
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
ckpt, b_approx = run_chen_greedy(b)
ckpt_intv, b_approx = run_chen_greedy(b)
if b_approx < b_opt:
b_opt = b_approx
ckpt_opt = ckpt
ckpt_opt = ckpt_intv
return ckpt_opt
def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
"""
ckpt = set()
ckpt_nodes = _all_potential_ckpt_nodes(gm)
ckpt_intv = []
temp = 0
x = 0
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
temp += getattr(n, 'activation_size')
y = max(y, temp)
if temp > b:
if temp > b and n in ckpt_nodes:
x += getattr(n, 'activation_size')
temp = 0
ckpt.add(idx)
return ckpt, math.floor(math.sqrt(x * y))
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
return ckpt_intv, math.floor(math.sqrt(x * y))
gm.graph.lint() # make sure nodes are in topological order
ckpt = grid_search(num_grids=6)
i = 0
for idx, n in enumerate(gm.graph.nodes):
if idx in ckpt:
setattr(n, 'activation_checkpoint', str(i))
i += 1
node_list = list(gm.graph.nodes)
for i, seg in enumerate(ckpt):
for idx in range(*seg):
n = node_list[idx]
if n.op in ['call_module', 'call_method', 'call_function']:
setattr(n, 'activation_checkpoint', str(i))
gm.recompile()
return gm
@@ -82,7 +94,9 @@ def chen_sqrtn(gm: GraphModule) -> GraphModule:
gm.graph.lint() # make sure nodes are in topological order
k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints
for idx, n in enumerate(gm.graph.nodes):
if (idx + 1) % k == 0:
# We should not add act_ckpt to the placeholder
# The last segment should not be checkpointed
if n.op != 'placeholder' and (idx + 1) // k < k:
setattr(n, 'activation_checkpoint', str((idx + 1) // k))
gm.recompile()
return gm