mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[fx] add vanilla activation checkpoint search with test on resnet and densenet (#1433)
* [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen.
This commit is contained in:
40
tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
Normal file
40
tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
|
||||
import torch
|
||||
import torchvision.models as tm
|
||||
from colossalai.fx import ColoTracer
|
||||
from torch.fx import GraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from functools import partial
|
||||
import pytest
|
||||
|
||||
SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]
|
||||
|
||||
|
||||
def _is_activation_checkpoint_available(gm: GraphModule):
|
||||
for n in gm.graph.nodes:
|
||||
if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None:
|
||||
return True
|
||||
|
||||
|
||||
def test_ckpt_solver():
|
||||
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
tracer = ColoTracer()
|
||||
data = torch.rand(1, 3, 224, 224)
|
||||
|
||||
for solver in SOLVERS:
|
||||
for model_cls in MODEL_LIST:
|
||||
model = model_cls()
|
||||
graph = tracer.trace(root=model)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
MetaInfoProp(gm).run(data)
|
||||
gm = solver(gm)
|
||||
assert _is_activation_checkpoint_available(
|
||||
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
||||
assert torch.allclose(gm(data), model(data))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ckpt_solver()
|
Reference in New Issue
Block a user