mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[fx] Fix activation codegen dealing with checkpointing first op (#1510)
This commit is contained in:
@@ -49,16 +49,20 @@ class MyModule(torch.nn.Module):
|
||||
self.relu = relu()
|
||||
self.linear2 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, x):
|
||||
def ckpt2(self, x):
|
||||
return F.relu(x, inplace=True)
|
||||
|
||||
def ckpt3(self, x, y):
|
||||
return self.linear2(x) + self.linear2(y)
|
||||
|
||||
def forward(self, x, y):
|
||||
y1, y2 = checkpoint(self.mlp1, x)
|
||||
y3 = checkpoint(self.relu, x)
|
||||
|
||||
def ckpt2(x):
|
||||
return F.relu(x, inplace=True)
|
||||
|
||||
y4 = checkpoint(ckpt2, x)
|
||||
y4 = self.linear2(y4)
|
||||
return y1 + y2 + y3 + y4
|
||||
y4 = checkpoint(self.ckpt2, y)
|
||||
y5 = checkpoint(self.ckpt3, y, y4)
|
||||
y6 = self.linear2(y4)
|
||||
return y1 + y2 + y3 + y4 + y5 + y6
|
||||
|
||||
|
||||
def _run_act_ckpt_codegen(rank):
|
||||
@@ -67,13 +71,15 @@ def _run_act_ckpt_codegen(rank):
|
||||
|
||||
# build model and run forward
|
||||
model = MyModule()
|
||||
data = torch.rand(4, 4)
|
||||
data1 = torch.rand(4, 4)
|
||||
data2 = torch.rand(4, 4)
|
||||
|
||||
# copy model to cuda
|
||||
model = model.to(device="cuda")
|
||||
data = data.to(device="cuda")
|
||||
data1 = data1.to(device="cuda")
|
||||
data2 = data2.to(device="cuda")
|
||||
|
||||
non_fx_out = model(data)
|
||||
non_fx_out = model(data1, data2)
|
||||
|
||||
# trace the module and replace codegen
|
||||
tracer = ColoTracer(trace_act_ckpt=True)
|
||||
@@ -99,12 +105,13 @@ def _run_act_ckpt_codegen(rank):
|
||||
# assert checkpoint function will be generated and
|
||||
# the offload option is correct
|
||||
code = graph.python_code('self').src
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code
|
||||
|
||||
# recompile and verify the outputs are consistent
|
||||
fx_out = gm(data)
|
||||
fx_out = gm(data1, data2)
|
||||
assert torch.equal(non_fx_out, fx_out)
|
||||
|
||||
gpc.destroy()
|
||||
@@ -121,13 +128,14 @@ def _run_act_ckpt_python_code_torch11(rank):
|
||||
|
||||
# build model and run forward
|
||||
model = MyModule()
|
||||
data = torch.rand(4, 4)
|
||||
data1 = torch.rand(4, 4)
|
||||
data2 = torch.rand(4, 4)
|
||||
|
||||
# copy model to cuda
|
||||
model = model.to(device="cuda")
|
||||
data = data.to(device="cuda")
|
||||
data1 = data1.to(device="cuda")
|
||||
data2 = data2.to(device="cuda")
|
||||
|
||||
non_fx_out = model(data)
|
||||
non_fx_out = model(data1, data2)
|
||||
|
||||
# trace the module and replace codegen
|
||||
tracer = ColoTracer(trace_act_ckpt=True)
|
||||
@@ -152,12 +160,13 @@ def _run_act_ckpt_python_code_torch11(rank):
|
||||
# assert checkpoint function will be generated and
|
||||
# the offload option is correct
|
||||
code = graph.python_code('self').src
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
|
||||
assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
|
||||
'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code
|
||||
|
||||
# recompile and verify the outputs are consistent
|
||||
fx_out = gm(data)
|
||||
fx_out = gm(data1, data2)
|
||||
assert torch.equal(non_fx_out, fx_out)
|
||||
|
||||
gpc.destroy()
|
||||
|
Reference in New Issue
Block a user