mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[fx] Add use_reentrant=False to checkpoint in codegen (#1463)
* [utils] Add use_reetrant=False into colossalai checkpoint * [utils] add some annotation in utils.activaion_checkpoint * [test] add reset_seed at the beginning of tests in test_actiavion_checkpointing.py * [test] modify test_activation_checkpoint.py * [test] modify test for reentrant=False * [fx] Add use_reentrant=False of checkpoint into codegen
This commit is contained in:
@@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
|
||||
return f"return {', '.join(output_vars)}"
|
||||
|
||||
|
||||
def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
|
||||
def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True):
|
||||
"""
|
||||
Generate the checkpoint function call code text
|
||||
"""
|
||||
outputs = ', '.join(output_vars)
|
||||
inputs = ', '.join(input_vars)
|
||||
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'
|
||||
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
|
||||
|
||||
|
||||
def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
|
||||
@@ -162,8 +162,24 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
|
||||
else:
|
||||
activation_offload = False
|
||||
|
||||
# we need to check if the checkpoint need use_reentrant=False
|
||||
use_reentrant = True
|
||||
for var in input_vars[label]:
|
||||
input_node = [item for item in node_list if item.name == var]
|
||||
input_node = input_node[0]
|
||||
for user in input_node.users:
|
||||
if hasattr(user, "activation_checkpoint"):
|
||||
if user.activation_checkpoint == label:
|
||||
if user.op == "call_module":
|
||||
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
|
||||
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
|
||||
|
||||
elif user.op == "call_function":
|
||||
if "inplace" in user.kwargs:
|
||||
use_reentrant = not user.kwargs["inplace"]
|
||||
|
||||
# generate checkpoint function call in a new line
|
||||
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
|
||||
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
|
||||
usage += '\n'
|
||||
body.append(usage)
|
||||
within_ckpt_region = False
|
||||
|
Reference in New Issue
Block a user