[fx] Fix ckpt functions' definitions in forward (#1476)

* [fx] fix defining ckpt functions inside forward

* [fx] Modify activation checkpoint codegen and add ColoGraphModule

* [fx] some modification

* some modifications

* some modifications

* some modifications

* some modifications

* some code modifications
This commit is contained in:
Boyuan Yao
2022-08-22 16:59:54 +08:00
committed by GitHub
parent bb5f5289e0
commit 1f2e547f7a
4 changed files with 217 additions and 39 deletions

View File

@@ -7,6 +7,7 @@ import torchvision.models as tm
from torch.fx import GraphModule
import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy
from colossalai.utils import free_port
@@ -72,7 +73,7 @@ def _run_ckpt_solver(rank):
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm).run(data)
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
@@ -102,7 +103,7 @@ def _run_ckpt_solver_torch11(rank):
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm).run(data)
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
gm = solver(gm)
@@ -114,10 +115,12 @@ def _run_ckpt_solver_torch11(rank):
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
def test_ckpt_solver_torch11():
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
if __name__ == '__main__':
test_ckpt_solver()
test_ckpt_solver_torch11()
_run_ckpt_solver(rank=0)
# test_ckpt_solver()
# test_ckpt_solver_torch11()

View File

@@ -9,6 +9,7 @@ from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -46,7 +47,7 @@ class MyModule(torch.nn.Module):
super().__init__()
self.mlp1 = MLP()
self.relu = relu()
self.linear3 = torch.nn.Linear(4, 4)
self.linear2 = torch.nn.Linear(4, 4)
def forward(self, x):
y1, y2 = checkpoint(self.mlp1, x)
@@ -56,6 +57,7 @@ class MyModule(torch.nn.Module):
return F.relu(x, inplace=True)
y4 = checkpoint(ckpt2, x)
y4 = self.linear2(y4)
return y1 + y2 + y3 + y4
@@ -91,15 +93,15 @@ def _run_act_ckpt_codegen(rank):
if node.name in offload_starts:
setattr(node, 'activation_offload', True)
gm = GraphModule(model, graph)
gm = ColoGraphModule(model, graph)
gm.recompile()
# assert checkpoint function will be generated and
# the offload option is correct
code = graph.python_code('self').src
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
# recompile and verify the outputs are consistent
fx_out = gm(data)
@@ -145,14 +147,14 @@ def _run_act_ckpt_python_code_torch11(rank):
if node.name in offload_starts:
setattr(node, 'activation_offload', True)
gm = GraphModule(model, graph)
gm = ColoGraphModule(model, graph)
gm.recompile()
# assert checkpoint function will be generated and
# the offload option is correct
code = graph.python_code('self').src
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
# recompile and verify the outputs are consistent
fx_out = gm(data)
@@ -162,11 +164,10 @@ def _run_act_ckpt_python_code_torch11(rank):
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
def test_act_ckpt_python_code_torch11():
mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)
if __name__ == '__main__':
test_act_ckpt_codegen()
test_act_ckpt_python_code_torch11()
_run_act_ckpt_codegen(rank=0)