[fx] added activation checkpoint codegen support for torch < 1.12 (#1359)

This commit is contained in:
Frank Lee
2022-07-25 23:35:31 +08:00
committed by GitHub
parent 4417804129
commit cd063ac37f
3 changed files with 435 additions and 192 deletions

View File

@@ -6,8 +6,11 @@ from colossalai.fx import ColoTracer
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
with_codegen = True
except:
pass
# fall back to older pytorch version
from colossalai.fx.codegen import python_code_with_activation_checkpoint
with_codegen = False
class MLP(torch.nn.Module):
@@ -35,7 +38,7 @@ class MyModule(torch.nn.Module):
return y1 + y2 + y3 + y4
@pytest.mark.skip("torch 1.12 is required")
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
# build model and run forward
model = MyModule()
@@ -65,5 +68,37 @@ def test_act_ckpt_codegen():
assert torch.equal(non_fx_out, fx_out)
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
def test_act_ckpt_python_code_torch11():
# build model and run forward
model = MyModule()
data = torch.rand(4, 4)
non_fx_out = model(data)
# trace the module and replace codegen
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model)
# replace a bound method of an object
graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
# check ops are annotated with ckpt
ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
for node in graph.nodes:
if node.name in ckpt_nodes:
assert hasattr(node, 'activation_checkpoint')
# assert checkpoint function will be generated
code = graph.python_code('self').src
assert 'checkpoint_0' in code and 'checkpoint_1' in code
# recompile and verify the outputs are consistent
gm = GraphModule(model, graph)
gm.recompile()
fx_out = gm(data)
assert torch.equal(non_fx_out, fx_out)
if __name__ == '__main__':
test_act_ckpt_codegen()
test_act_ckpt_python_code_torch11()