diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index 5978dd315..def67c60f 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -165,9 +165,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
 
             # we need to check if the checkpoint need use_reentrant=False
             use_reentrant = True
+            non_leaf_input = 0
             for var in input_vars[label]:
                 input_node = [item for item in node_list if item.name == var]
                 input_node = input_node[0]
+                if input_node.op != "placeholder":
+                    non_leaf_input = 1
                 for user in input_node.users:
                     if hasattr(user, "activation_checkpoint"):
                         if user.activation_checkpoint == label:
@@ -179,6 +182,10 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
                                 if "inplace" in user.kwargs:
                                     use_reentrant = not user.kwargs["inplace"]
 
+            # if all the inputs are leaf nodes, we need to set use_reentrant = False
+            if not non_leaf_input:
+                use_reentrant = False
+
             # generate checkpoint function call in a new line
             usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
             usage += '\n'
diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
index 368222dfe..54a11bb48 100644
--- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
+++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
@@ -49,16 +49,20 @@ class MyModule(torch.nn.Module):
         self.relu = relu()
         self.linear2 = torch.nn.Linear(4, 4)
 
-    def forward(self, x):
+    def ckpt2(self, x):
+        return F.relu(x, inplace=True)
+
+    def ckpt3(self, x, y):
+        return self.linear2(x) + self.linear2(y)
+
+    def forward(self, x, y):
         y1, y2 = checkpoint(self.mlp1, x)
         y3 = checkpoint(self.relu, x)
 
-        def ckpt2(x):
-            return F.relu(x, inplace=True)
-
-        y4 = checkpoint(ckpt2, x)
-        y4 = self.linear2(y4)
-        return y1 + y2 + y3 + y4
+        y4 = checkpoint(self.ckpt2, y)
+        y5 = checkpoint(self.ckpt3, y, y4)
+        y6 = self.linear2(y4)
+        return y1 + y2 + y3 + y4 + y5 + y6
 
 
 def _run_act_ckpt_codegen(rank):
@@ -67,13 +71,15 @@ def _run_act_ckpt_codegen(rank):
 
     # build model and run forward
     model = MyModule()
-    data = torch.rand(4, 4)
+    data1 = torch.rand(4, 4)
+    data2 = torch.rand(4, 4)
 
     # copy model to cuda
     model = model.to(device="cuda")
-    data = data.to(device="cuda")
+    data1 = data1.to(device="cuda")
+    data2 = data2.to(device="cuda")
 
-    non_fx_out = model(data)
+    non_fx_out = model(data1, data2)
 
     # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
@@ -99,12 +105,13 @@ def _run_act_ckpt_codegen(rank):
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
         'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code
 
     # recompile and verify the outputs are consistent
-    fx_out = gm(data)
+    fx_out = gm(data1, data2)
     assert torch.equal(non_fx_out, fx_out)
 
     gpc.destroy()
@@ -121,13 +128,14 @@ def _run_act_ckpt_python_code_torch11(rank):
 
     # build model and run forward
     model = MyModule()
-    data = torch.rand(4, 4)
+    data1 = torch.rand(4, 4)
+    data2 = torch.rand(4, 4)
 
     # copy model to cuda
-    model = model.to(device="cuda")
-    data = data.to(device="cuda")
+    data1 = data1.to(device="cuda")
+    data2 = data2.to(device="cuda")
 
-    non_fx_out = model(data)
+    non_fx_out = model(data1, data2)
 
     # trace the module and replace codegen
     tracer = ColoTracer(trace_act_ckpt=True)
@@ -152,12 +160,13 @@ def _run_act_ckpt_python_code_torch11(rank):
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
         'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code
 
     # recompile and verify the outputs are consistent
-    fx_out = gm(data)
+    fx_out = gm(data1, data2)
     assert torch.equal(non_fx_out, fx_out)
 
     gpc.destroy()
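
Background for the `non_leaf_input` / `use_reentrant` change above: a minimal standalone sketch, not part of this PR, using plain `torch.utils.checkpoint` rather than the `colossalai.utils.activation_checkpoint` wrapper (assumed here to forward `use_reentrant` to the same underlying mechanism). It illustrates why a checkpointed region whose inputs are all graph placeholders (leaf tensors that typically have `requires_grad=False`) is presumably emitted with `use_reentrant=False`: with the reentrant implementation, the region's output is disconnected from autograd when no input requires grad, so parameters used inside the region would get no gradient.

```python
# Hypothetical sketch (not from the PR), assuming a PyTorch version whose
# torch.utils.checkpoint.checkpoint accepts the use_reentrant keyword.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

linear = nn.Linear(4, 4)
x = torch.rand(4, 4)    # leaf input, requires_grad=False, like an fx placeholder

# Reentrant checkpointing: since no input requires grad, the output carries no
# grad_fn, so a backward pass could never reach linear.weight.
out = checkpoint(linear, x, use_reentrant=True)
print(out.requires_grad)    # False

# Non-reentrant checkpointing preserves the autograd graph via saved-tensor
# hooks, so parameter gradients still flow from the checkpointed region.
out = checkpoint(linear, x, use_reentrant=False)
out.sum().backward()
print(linear.weight.grad is not None)    # True
```

The updated test reflects the same idea: `checkpoint_0` takes only the placeholder `x`, so the generated call now asserts `use_reentrant=False`, while `checkpoint_3` also consumes an intermediate activation and keeps `use_reentrant=True`.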