mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
[fx] Modify solver linearize and add corresponding test (#1531)
* [fx] modify solver linearize and add test * [fx] add torch11 test of linearize but skip it * [fx] remove some unused imports
This commit is contained in:
128
tests/test_fx/test_ckpt_solvers/test_linearize.py
Normal file
128
tests/test_fx/test_ckpt_solvers/test_linearize.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch
|
||||
import torchvision.models as tm
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.algorithms import solver_rotor, linearize
|
||||
from colossalai.fx.passes.algorithms.utils import Loss, ForwardCheck, ForwardEnable, ForwardNograd
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from colossalai.fx.codegen import ActivationCheckpointCodeGen
|
||||
with_codegen = True
|
||||
except:
|
||||
# fall back to older pytorch version
|
||||
from colossalai.fx.codegen import python_code_with_activation_checkpoint
|
||||
with_codegen = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
|
||||
def test_linearize():
|
||||
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
|
||||
tracer = ColoTracer()
|
||||
for M, budgets in MODEL_DICT.items():
|
||||
for budget in budgets:
|
||||
model = M()
|
||||
graph = tracer.trace(model)
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
gm = ColoGraphModule(model, graph, model.__class__.__name__)
|
||||
node_list = linearize(gm)
|
||||
gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
|
||||
op_list = gm.__sequence__.list_operations()
|
||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
||||
op_list = op_list[:op_list.index(loss_op)]
|
||||
in_ckpt = False
|
||||
ckpt_idx = 0
|
||||
for idx, op in enumerate(op_list):
|
||||
if in_ckpt:
|
||||
if isinstance(op, ForwardNograd):
|
||||
for n in node_list[idx]:
|
||||
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
|
||||
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
|
||||
|
||||
continue
|
||||
|
||||
if isinstance(op, ForwardEnable):
|
||||
for n in node_list[idx]:
|
||||
assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!"
|
||||
in_ckpt = False
|
||||
|
||||
ckpt_idx += 1
|
||||
continue
|
||||
|
||||
if isinstance(op, ForwardCheck):
|
||||
ckpt_idx += 1
|
||||
for n in node_list[idx]:
|
||||
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
|
||||
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
|
||||
|
||||
continue
|
||||
|
||||
else:
|
||||
if isinstance(op, ForwardCheck):
|
||||
in_ckpt = True
|
||||
for n in node_list[idx]:
|
||||
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
|
||||
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
|
||||
|
||||
del model
|
||||
del gm
|
||||
del node_list
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
|
||||
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
|
||||
def test_linearize_torch11():
|
||||
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
|
||||
tracer = ColoTracer()
|
||||
for M, budgets in MODEL_DICT.items():
|
||||
for budget in budgets:
|
||||
model = M()
|
||||
graph = tracer.trace(model)
|
||||
gm = ColoGraphModule(model, graph, model.__class__.__name__)
|
||||
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
|
||||
node_list = linearize(gm)
|
||||
gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2)
|
||||
op_list = gm.__sequence__.list_operations()
|
||||
loss_op = next(op for op in op_list if isinstance(op, Loss))
|
||||
op_list = op_list[:op_list.index(loss_op)]
|
||||
in_ckpt = False
|
||||
ckpt_idx = 0
|
||||
for idx, op in enumerate(op_list):
|
||||
if in_ckpt:
|
||||
if isinstance(op, ForwardNograd):
|
||||
for n in node_list[idx]:
|
||||
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
|
||||
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
|
||||
|
||||
continue
|
||||
|
||||
if isinstance(op, ForwardEnable):
|
||||
for n in node_list[idx]:
|
||||
assert getattr(n, "activation_checkpoint", None) == None, f"{n} should not be annotated!"
|
||||
in_ckpt = False
|
||||
|
||||
ckpt_idx += 1
|
||||
continue
|
||||
|
||||
if isinstance(op, ForwardCheck):
|
||||
ckpt_idx += 1
|
||||
for n in node_list[idx]:
|
||||
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
|
||||
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
|
||||
|
||||
continue
|
||||
|
||||
else:
|
||||
if isinstance(op, ForwardCheck):
|
||||
in_ckpt = True
|
||||
for n in node_list[idx]:
|
||||
assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!"
|
||||
assert n.activation_checkpoint == ckpt_idx, f"{n} ckpt_idx wrong, should be {ckpt_idx}!"
|
||||
|
||||
del model
|
||||
del gm
|
||||
del node_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_linearize()
|
Reference in New Issue
Block a user