mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[fx] Add use_reentrant=False to checkpoint in codegen (#1463)
* [utils] Add use_reetrant=False into colossalai checkpoint * [utils] add some annotation in utils.activaion_checkpoint * [test] add reset_seed at the beginning of tests in test_actiavion_checkpointing.py * [test] modify test_activation_checkpoint.py * [test] modify test for reentrant=False * [fx] Add use_reentrant=False of checkpoint into codegen
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from operator import mod
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
@@ -26,7 +27,17 @@ class MLP(torch.nn.Module):
|
||||
self.linear2 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear1(x), self.linear1(x)
|
||||
return self.linear1(x), self.linear2(x)
|
||||
|
||||
|
||||
class relu(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.relu = torch.nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.relu(x)
|
||||
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
@@ -34,12 +45,17 @@ class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mlp1 = MLP()
|
||||
self.mlp2 = MLP()
|
||||
self.relu = relu()
|
||||
self.linear3 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
y1, y2 = checkpoint(self.mlp1, x)
|
||||
y3, y4 = checkpoint(self.mlp2, x)
|
||||
y3 = checkpoint(self.relu, x)
|
||||
|
||||
def ckpt2(x):
|
||||
return F.relu(x, inplace=True)
|
||||
|
||||
y4 = checkpoint(ckpt2, x)
|
||||
return y1 + y2 + y3 + y4
|
||||
|
||||
|
||||
@@ -65,8 +81,8 @@ def _run_act_ckpt_codegen(rank):
|
||||
|
||||
# check ops are annotated with ckpt
|
||||
# also annotate the selected node for offloading
|
||||
ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
|
||||
offload_starts = ['mlp2_linear1']
|
||||
ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
|
||||
offload_starts = ['mlp1_linear1']
|
||||
for node in graph.nodes:
|
||||
if node.name in ckpt_nodes:
|
||||
assert hasattr(node, 'activation_checkpoint')
|
||||
@@ -75,15 +91,17 @@ def _run_act_ckpt_codegen(rank):
|
||||
if node.name in offload_starts:
|
||||
setattr(node, 'activation_offload', True)
|
||||
|
||||
gm = GraphModule(model, graph)
|
||||
gm.recompile()
|
||||
|
||||
# assert checkpoint function will be generated and
|
||||
# the offload option is correct
|
||||
code = graph.python_code('self').src
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
|
||||
|
||||
# recompile and verify the outputs are consistent
|
||||
gm = GraphModule(model, graph)
|
||||
gm.recompile()
|
||||
fx_out = gm(data)
|
||||
assert torch.equal(non_fx_out, fx_out)
|
||||
|
||||
@@ -117,8 +135,8 @@ def _run_act_ckpt_python_code_torch11(rank):
|
||||
graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
|
||||
|
||||
# check ops are annotated with ckpt
|
||||
ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
|
||||
offload_starts = ['mlp2_linear1']
|
||||
ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
|
||||
offload_starts = ['mlp1_linear1']
|
||||
for node in graph.nodes:
|
||||
if node.name in ckpt_nodes:
|
||||
assert hasattr(node, 'activation_checkpoint')
|
||||
@@ -127,15 +145,16 @@ def _run_act_ckpt_python_code_torch11(rank):
|
||||
if node.name in offload_starts:
|
||||
setattr(node, 'activation_offload', True)
|
||||
|
||||
gm = GraphModule(model, graph)
|
||||
gm.recompile()
|
||||
# assert checkpoint function will be generated and
|
||||
# the offload option is correct
|
||||
code = graph.python_code('self').src
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
|
||||
|
||||
# recompile and verify the outputs are consistent
|
||||
gm = GraphModule(model, graph)
|
||||
gm.recompile()
|
||||
fx_out = gm(data)
|
||||
assert torch.equal(non_fx_out, fx_out)
|
||||
|
||||
|
Reference in New Issue
Block a user