mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[fx] Fix activation codegen dealing with checkpointing first op (#1510)
This commit is contained in:
@@ -165,9 +165,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||
|
||||
# we need to check if the checkpoint need use_reentrant=False
|
||||
use_reentrant = True
|
||||
non_leaf_input = 0
|
||||
for var in input_vars[label]:
|
||||
input_node = [item for item in node_list if item.name == var]
|
||||
input_node = input_node[0]
|
||||
if input_node.op != "placeholder":
|
||||
non_leaf_input = 1
|
||||
for user in input_node.users:
|
||||
if hasattr(user, "activation_checkpoint"):
|
||||
if user.activation_checkpoint == label:
|
||||
@@ -179,6 +182,10 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||
if "inplace" in user.kwargs:
|
||||
use_reentrant = not user.kwargs["inplace"]
|
||||
|
||||
# if all the inputs are leaf nodes, we need to set use_reentrant = False
|
||||
if not non_leaf_input:
|
||||
use_reentrant = False
|
||||
|
||||
# generate checkpoint function call in a new line
|
||||
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
|
||||
usage += '\n'
|
||||
|
Reference in New Issue
Block a user