[fx] Add use_reentrant=False to checkpoint in codegen (#1463)

* [utils] Add use_reentrant=False to the colossalai checkpoint

* [utils] add annotations in utils.activation_checkpoint

* [test] add reset_seed at the beginning of tests in test_activation_checkpointing.py

* [test] modify test_activation_checkpoint.py

* [test] modify tests for use_reentrant=False

* [fx] Add use_reentrant=False to checkpoint in codegen
Author:  Boyuan Yao (committed by GitHub)
Date:    2022-08-17 10:34:50 +08:00
Parent:  47fd8e4a02
Commit:  092b9c8f49
2 changed files with 53 additions and 18 deletions
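For context on the flag this commit threads through the codegen: use_reentrant is the keyword accepted by PyTorch's torch.utils.checkpoint.checkpoint, selecting between the classic autograd.Function-based recomputation and the newer saved-tensor-hook implementation. A minimal standalone sketch of the two modes in plain PyTorch (the layer and tensor names below are illustrative, not code from this PR):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)
x = torch.rand(2, 4, requires_grad=True)

# Reentrant mode (the historical default): the forward is re-run inside a
# custom autograd.Function during backward; it expects at least one input
# with requires_grad=True.
y_reentrant = checkpoint(layer, x, use_reentrant=True)
y_reentrant.sum().backward()

# Non-reentrant mode: recomputation is driven by saved-tensor hooks, which
# also supports keyword arguments and torch.autograd.grad().
y_non_reentrant = checkpoint(layer, x, use_reentrant=False)
y_non_reentrant.sum().backward()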


@@ -1,5 +1,6 @@
 from operator import mod
 import torch
+import torch.nn.functional as F
 import pytest
 import torch.multiprocessing as mp
 from torch.utils.checkpoint import checkpoint
@@ -26,7 +27,17 @@ class MLP(torch.nn.Module):
         self.linear2 = torch.nn.Linear(4, 4)
     def forward(self, x):
-        return self.linear1(x), self.linear1(x)
+        return self.linear1(x), self.linear2(x)
+class relu(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.relu = torch.nn.ReLU(inplace=True)
+    def forward(self, x):
+        return self.relu(x)
 class MyModule(torch.nn.Module):
@@ -34,12 +45,17 @@ class MyModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.mlp1 = MLP()
-        self.mlp2 = MLP()
+        self.relu = relu()
+        self.linear3 = torch.nn.Linear(4, 4)
     def forward(self, x):
         y1, y2 = checkpoint(self.mlp1, x)
-        y3, y4 = checkpoint(self.mlp2, x)
+        y3 = checkpoint(self.relu, x)
+        def ckpt2(x):
+            return F.relu(x, inplace=True)
+        y4 = checkpoint(ckpt2, x)
         return y1 + y2 + y3 + y4
@@ -65,8 +81,8 @@ def _run_act_ckpt_codegen(rank):
     # check ops are annotated with ckpt
     # also annotate the selected node for offloading
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
-    offload_starts = ['mlp2_linear1']
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
+    offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
@@ -75,15 +91,17 @@ def _run_act_ckpt_codegen(rank):
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)
-    gm = GraphModule(model, graph)
-    gm.recompile()
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
+    # recompile and verify the outputs are consistent
+    gm = GraphModule(model, graph)
+    gm.recompile()
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
@@ -117,8 +135,8 @@ def _run_act_ckpt_python_code_torch11(rank):
     graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
     # check ops are annotated with ckpt
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
-    offload_starts = ['mlp2_linear1']
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
+    offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
@@ -127,15 +145,16 @@ def _run_act_ckpt_python_code_torch11(rank):
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)
-    gm = GraphModule(model, graph)
-    gm.recompile()
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
+    # recompile and verify the outputs are consistent
+    gm = GraphModule(model, graph)
+    gm.recompile()
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
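Judging from the strings asserted in both test functions above, the codegen now emits calls of the form colossalai.utils.activation_checkpoint.checkpoint(ckpt_fn, activation_offload, *args, use_reentrant=...). A rough sketch of that calling convention used by hand, assuming the signature implied by those assertions (the module and input are placeholders, not part of this change):

import torch
from colossalai.utils.activation_checkpoint import checkpoint

# Placeholder module and data, only to illustrate the call shape asserted
# above: checkpoint(fn, activation_offload, *args, use_reentrant=...).
mlp = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
x = torch.rand(2, 4, requires_grad=True)

# No CPU offload, non-reentrant recomputation -- mirrors the generated
# 'checkpoint(checkpoint_1, False, x, use_reentrant=False)' call.
out = checkpoint(mlp, False, x, use_reentrant=False)
out.sum().backward()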