mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00
[fx] fix test and algorithm bugs in activation checkpointing. (#1451)
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] fix test and algorithm bugs in activation checkpointing. * mend [fx] fix test and algorithm bugs in activation checkpointing. * mend [fx] fix test and algorithm bugs in activation checkpointing. * mend [fx] fix test and algorithm bugs in activation checkpointing. * mend [fx] fix test and algorithm bugs in activation checkpointing. * [fx] polish ckpt_test. * [fx] polish ckpt_test. * [fx] polish ckpt_test.
This commit is contained in:
parent
b1553fdf96
commit
0dbd61c29b
@ -1,4 +1,4 @@
|
|||||||
from typing import Set, Tuple
|
from typing import List, Set, Tuple
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
import math
|
import math
|
||||||
@ -6,6 +6,14 @@ import math
|
|||||||
__all__ = ['chen_greedy', 'chen_sqrtn']
|
__all__ = ['chen_greedy', 'chen_sqrtn']
|
||||||
|
|
||||||
|
|
||||||
|
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
|
||||||
|
ckpt_nodes = []
|
||||||
|
for n in gm.graph.nodes:
|
||||||
|
if n.op == 'call_module':
|
||||||
|
ckpt_nodes.append(n)
|
||||||
|
return ckpt_nodes
|
||||||
|
|
||||||
|
|
||||||
def chen_greedy(gm: GraphModule) -> GraphModule:
|
def chen_greedy(gm: GraphModule) -> GraphModule:
|
||||||
"""
|
"""
|
||||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||||
@ -31,36 +39,40 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
|
|||||||
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
|
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
|
||||||
b_opt = math.inf
|
b_opt = math.inf
|
||||||
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
|
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
|
||||||
ckpt, b_approx = run_chen_greedy(b)
|
ckpt_intv, b_approx = run_chen_greedy(b)
|
||||||
if b_approx < b_opt:
|
if b_approx < b_opt:
|
||||||
b_opt = b_approx
|
b_opt = b_approx
|
||||||
ckpt_opt = ckpt
|
ckpt_opt = ckpt_intv
|
||||||
return ckpt_opt
|
return ckpt_opt
|
||||||
|
|
||||||
def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
|
def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
|
||||||
"""
|
"""
|
||||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||||
"""
|
"""
|
||||||
ckpt = set()
|
ckpt_nodes = _all_potential_ckpt_nodes(gm)
|
||||||
|
ckpt_intv = []
|
||||||
temp = 0
|
temp = 0
|
||||||
x = 0
|
x = 0
|
||||||
y = 0
|
y = 0
|
||||||
|
prev_idx = 2
|
||||||
for (idx, n) in enumerate(gm.graph.nodes):
|
for (idx, n) in enumerate(gm.graph.nodes):
|
||||||
temp += getattr(n, 'activation_size')
|
temp += getattr(n, 'activation_size')
|
||||||
y = max(y, temp)
|
y = max(y, temp)
|
||||||
if temp > b:
|
if temp > b and n in ckpt_nodes:
|
||||||
x += getattr(n, 'activation_size')
|
x += getattr(n, 'activation_size')
|
||||||
temp = 0
|
temp = 0
|
||||||
ckpt.add(idx)
|
ckpt_intv.append((prev_idx, idx + 1))
|
||||||
return ckpt, math.floor(math.sqrt(x * y))
|
prev_idx = idx + 1
|
||||||
|
return ckpt_intv, math.floor(math.sqrt(x * y))
|
||||||
|
|
||||||
gm.graph.lint() # make sure nodes are in topological order
|
gm.graph.lint() # make sure nodes are in topological order
|
||||||
ckpt = grid_search(num_grids=6)
|
ckpt = grid_search(num_grids=6)
|
||||||
i = 0
|
node_list = list(gm.graph.nodes)
|
||||||
for idx, n in enumerate(gm.graph.nodes):
|
for i, seg in enumerate(ckpt):
|
||||||
if idx in ckpt:
|
for idx in range(*seg):
|
||||||
setattr(n, 'activation_checkpoint', str(i))
|
n = node_list[idx]
|
||||||
i += 1
|
if n.op in ['call_module', 'call_method', 'call_function']:
|
||||||
|
setattr(n, 'activation_checkpoint', str(i))
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
@ -82,7 +94,9 @@ def chen_sqrtn(gm: GraphModule) -> GraphModule:
|
|||||||
gm.graph.lint() # make sure nodes are in topological order
|
gm.graph.lint() # make sure nodes are in topological order
|
||||||
k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints
|
k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints
|
||||||
for idx, n in enumerate(gm.graph.nodes):
|
for idx, n in enumerate(gm.graph.nodes):
|
||||||
if (idx + 1) % k == 0:
|
# We should not add act_ckpt to the placeholder
|
||||||
|
# The last segment should not be checkpointed
|
||||||
|
if n.op != 'placeholder' and (idx + 1) // k < k:
|
||||||
setattr(n, 'activation_checkpoint', str((idx + 1) // k))
|
setattr(n, 'activation_checkpoint', str((idx + 1) // k))
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
@ -1,12 +1,25 @@
|
|||||||
from ctypes import Union
|
from typing import Callable
|
||||||
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
|
import copy
|
||||||
import torch
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
import torchvision.models as tm
|
import torchvision.models as tm
|
||||||
from colossalai.fx import ColoTracer
|
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
|
import colossalai
|
||||||
|
from colossalai.fx import ColoTracer
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
|
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
try:
|
||||||
|
from colossalai.fx.codegen import ActivationCheckpointCodeGen
|
||||||
|
with_codegen = True
|
||||||
|
except:
|
||||||
|
# fall back to older pytorch version
|
||||||
|
from colossalai.fx.codegen import python_code_with_activation_checkpoint
|
||||||
|
with_codegen = False
|
||||||
|
|
||||||
SOLVERS = [chen_greedy, chen_sqrtn]
|
SOLVERS = [chen_greedy, chen_sqrtn]
|
||||||
|
|
||||||
|
|
||||||
@ -18,37 +31,80 @@ def _is_activation_checkpoint_available(gm: GraphModule):
|
|||||||
|
|
||||||
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
|
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
|
||||||
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
|
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
|
||||||
if not torch.allclose(m_p, gm_p):
|
if not torch.allclose(m_p.grad, gm_p.grad):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_ckpt_solver():
|
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
|
||||||
|
model_cls: Callable[[], torch.nn.Module]):
|
||||||
|
criterion = torch.nn.MSELoss()
|
||||||
|
data = torch.rand(2, 3, 32, 32)
|
||||||
|
label = torch.rand(2, 5)
|
||||||
|
loss = criterion(m(data), label)
|
||||||
|
loss.backward()
|
||||||
|
loss = criterion(gm(data), label)
|
||||||
|
loss.backward()
|
||||||
|
assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
|
||||||
|
|
||||||
|
|
||||||
|
def _run_ckpt_solver(rank):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||||
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
||||||
|
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer(trace_act_ckpt=False)
|
||||||
data = torch.rand(1, 3, 224, 224)
|
|
||||||
label = torch.rand(1, 1000)
|
|
||||||
|
|
||||||
|
data = torch.rand(2, 3, 32, 32)
|
||||||
for solver in SOLVERS:
|
for solver in SOLVERS:
|
||||||
for model_cls in MODEL_LIST:
|
for model_cls in MODEL_LIST:
|
||||||
model = model_cls()
|
m = model_cls(num_classes=5)
|
||||||
criterion = torch.nn.MSELoss()
|
graph = tracer.trace(root=m)
|
||||||
graph = tracer.trace(root=model)
|
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
|
||||||
MetaInfoProp(gm).run(data)
|
MetaInfoProp(gm).run(data)
|
||||||
|
codegen = ActivationCheckpointCodeGen()
|
||||||
|
gm.graph.set_codegen(codegen)
|
||||||
gm = solver(gm)
|
gm = solver(gm)
|
||||||
assert _is_activation_checkpoint_available(
|
assert _is_activation_checkpoint_available(
|
||||||
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
||||||
loss = criterion(model(data), label)
|
check_backward_consistency(m, gm, solver, model_cls)
|
||||||
loss.backward()
|
|
||||||
loss = criterion(gm(data), label)
|
|
||||||
loss.backward()
|
@pytest.mark.skip
|
||||||
assert _is_all_gradient_close(model,
|
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
|
||||||
gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
|
def test_ckpt_solver():
|
||||||
|
mp.spawn(_run_ckpt_solver, nprocs=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_ckpt_solver_torch11(rank):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||||
|
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
||||||
|
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
|
tracer = ColoTracer(trace_act_ckpt=False)
|
||||||
|
|
||||||
|
data = torch.rand(2, 3, 32, 32)
|
||||||
|
for solver in SOLVERS:
|
||||||
|
for model_cls in MODEL_LIST:
|
||||||
|
m = model_cls(num_classes=5)
|
||||||
|
graph = tracer.trace(root=m)
|
||||||
|
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
|
||||||
|
MetaInfoProp(gm).run(data)
|
||||||
|
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
|
||||||
|
gm = solver(gm)
|
||||||
|
assert _is_activation_checkpoint_available(
|
||||||
|
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
||||||
|
check_backward_consistency(m, gm, solver, model_cls)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip
|
||||||
|
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
|
||||||
|
def test_ckpt_solver_torch11():
|
||||||
|
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_ckpt_solver()
|
test_ckpt_solver()
|
||||||
|
test_ckpt_solver_torch11()
|
||||||
|
Loading…
Reference in New Issue
Block a user