[fx] fix the incorrect interpretation of Algorithm 3 in https://arxiv.org/abs/1604.06174 (#1446)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usage

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* mend

* [fx] add vanilla activation checkpoint search with tests on ResNet and DenseNet

* [fx] add a namespace for solver_chen.

* [fx] fix the incorrect interpretation of Algorithm 3 in https://arxiv.org/abs/1604.06174 (see the sketch after this list).

* [fx] fix lowercase naming conventions.
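
For reference, Algorithm 3 of arXiv:1604.06174 greedily emits a checkpoint whenever the running activation size exceeds a budget b, and the budget itself is chosen by grid search around the square root of the total activation size rather than being supplied by the caller; that is why the test diff below replaces partial(chen_greedy, B=1024 * 1024 * 64) with plain chen_greedy. A minimal sketch of the greedy search, assuming only a flat list of per-node activation sizes (chen_greedy_sketch and node_sizes are illustrative names, not the library API):

import math
from typing import List, Tuple

def chen_greedy_sketch(node_sizes: List[int]) -> Tuple[List[int], float]:
    # Total activation memory of the traced network.
    x = sum(node_sizes)
    best_plan, best_cost = [], math.inf
    # Grid-search candidate budgets around sqrt(x), as the paper suggests.
    for b in (math.sqrt(x) * k for k in (0.5, 1.0, 1.5, 2.0)):
        plan, temp, y = [], 0, 0    # y accumulates checkpointed memory
        for i, size in enumerate(node_sizes):
            temp += size
            if temp > b:
                plan.append(i)      # mark node i as a checkpoint
                y += size
                temp = 0
        cost = b + y                # rough peak: budget + stored checkpoints
        if cost < best_cost:
            best_plan, best_cost = plan, cost
    return best_plan, best_cost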
Super Daniel
2022-08-12 11:28:50 +08:00
committed by GitHub
parent 821c6172e2
commit d40a9392ba
2 changed files with 62 additions and 22 deletions
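
Both solvers consume the per-node sizes recorded by MetaInfoProp(gm).run(data), which pushes shape-only meta tensors through the traced graph and annotates every node with the size of its output; the node_size changes listed above adjust exactly that bookkeeping for checkpointing. A minimal sketch of the size computation, assuming a node whose output is a single tensor (activation_size is an illustrative helper, not the library function):

import torch

def activation_size(out: torch.Tensor) -> int:
    # Bytes required to keep this activation alive until the backward pass.
    return out.numel() * torch.tensor([], dtype=out.dtype).element_size()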

@@ -1,13 +1,13 @@
from typing import Union
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from functools import partial
import pytest
-SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]
+SOLVERS = [chen_greedy, chen_sqrtn]
def _is_activation_checkpoint_available(gm: GraphModule):
@@ -16,6 +16,13 @@ def _is_activation_checkpoint_available(gm: GraphModule):
    return True
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
    # Compare the accumulated gradients parameter by parameter.
    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
        if not torch.allclose(m_p.grad, gm_p.grad):
            return False
    return True
def test_ckpt_solver():
    MODEL_LIST = [tm.resnet18, tm.densenet121]
@@ -23,17 +30,24 @@ def test_ckpt_solver():
    tracer = ColoTracer()
    data = torch.rand(1, 3, 224, 224)
    label = torch.rand(1, 1000)
    for solver in SOLVERS:
        for model_cls in MODEL_LIST:
            model = model_cls()
            criterion = torch.nn.MSELoss()
            graph = tracer.trace(root=model)
            gm = GraphModule(model, graph, model.__class__.__name__)
            MetaInfoProp(gm).run(data)
            gm = solver(gm)
            assert _is_activation_checkpoint_available(
                gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
            assert torch.allclose(gm(data), model(data))
            loss = criterion(model(data), label)
            loss.backward()
            loss = criterion(gm(data), label)
            loss.backward()
            assert _is_all_gradient_close(model,
                                          gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
if __name__ == '__main__':
    test_ckpt_solver()
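
The other solver exercised here, chen_sqrtn, is the classic O(sqrt(n)) schedule from the same paper: partition the n nodes into roughly sqrt(n) segments and checkpoint each segment boundary, so only one segment's activations plus the boundaries live in memory at once. A minimal sketch of the boundary choice (sqrtn_checkpoints is an illustrative helper, not the library function):

import math
from typing import List

def sqrtn_checkpoints(num_nodes: int) -> List[int]:
    # Checkpoint every ~sqrt(n)-th node, yielding O(sqrt(n)) peak
    # activation memory at the cost of one extra forward pass.
    k = max(int(math.sqrt(num_nodes)), 1)
    return list(range(k, num_nodes, k))

For a 100-node graph this marks nodes 10, 20, ..., 90, so at most one 10-node segment is ever re-materialized during the backward pass.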