Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
[fx] Use colossalai checkpoint and add offload recognition in codegen (#1439)
* [fx] Use colossalai.utils.checkpoint to replace torch.utils.checkpoint for offload activation and add offload annotation recognition in codegen
* Modification of test and add TODO in codegen
* [fx] Modification of colossal ckpt usage
* [fx] add gpc.destroy() to test_codegen
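In short, the new codegen reads an activation_offload attribute from the first node of each checkpoint region, threads it into ColossalAI's checkpoint wrapper in place of torch.utils.checkpoint.checkpoint, and injects an inline import colossalai into the generated forward so the call can resolve. A minimal sketch of the kind of forward source this should produce for a single offloaded region (all variable and region names are placeholders, not taken from the repository):

def forward(self, x):
    import colossalai                     # inline import injected by the new codegen
    def checkpoint_0(x):                  # checkpoint region extracted by the codegen
        relu = torch.relu(x)
        return relu
    # the offload flag (True here) is passed positionally, right after the region function
    relu = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x)
    return relu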
@@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
     return f"return {', '.join(output_vars)}"
 
 
-def _gen_ckpt_usage(label, input_vars, output_vars):
+def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
     """
     Generate the checkpoint function call code text
     """
     outputs = ', '.join(output_vars)
     inputs = ', '.join(input_vars)
-    return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
+    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'
 
 
 def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
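For a concrete sense of what the rewritten helper emits, a call with made-up arguments renders the ColossalAI wrapper with the offload flag in second position (the argument values below are illustrative only):

print(_gen_ckpt_usage(0, True, ['x0'], ['out']))
# out = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x0)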
@@ -155,8 +155,15 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
             return_statement = f'    {return_statement}\n'
             body.append(return_statement)
 
+            # we need to check if the checkpoint need to offload the input
+            start_node_idx = start_idx[label]
+            if hasattr(node_list[start_node_idx], 'activation_offload'):
+                activation_offload = node_list[start_node_idx].activation_offload
+            else:
+                activation_offload = False
+
             # generate checkpoint function call in a new line
-            usage = _gen_ckpt_usage(label, input_vars[label], output_vars[label])
+            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
             usage += '\n'
             body.append(usage)
             within_ckpt_region = False
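Because only the region's first node is inspected, whichever pass (or test) tags that node controls offloading for the whole region. A hand-rolled sketch of such tagging, assuming the region label is stored on each node as activation_checkpoint (that attribute name is an assumption; activation_offload is the attribute this hunk actually reads):

def offload_region(graph, region_label):
    # Sketch only: mark the first node of the given checkpoint region so the
    # codegen above passes activation_offload=True for that region.
    for node in graph.nodes:
        if getattr(node, 'activation_checkpoint', None) == region_label:
            node.activation_offload = True
            break  # the codegen only looks at the region's first node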
@@ -368,7 +375,11 @@ if codegen_available:
             for name, value in self.additional_globals():
                 add_global(name, value)
 
+            # as we need colossalai.utils.checkpoint, we need to import colossalai
+            # in forward function
+            # TODO: Remove inline import
             prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+            prologue = prologue + "\n    import colossalai"
 
             code = ''.join(body)
             code = '\n'.join('    ' + line for line in code.split('\n'))
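The effect on the prologue is simply one extra indented line appended after the generated signature; for example (the signature text is illustrative):

prologue = "def forward(self, x):"
prologue = prologue + "\n    import colossalai"
# def forward(self, x):
#     import colossalai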
@@ -566,9 +577,14 @@ else:
         orig_args.insert(0, 'self')
         code = ''.join(body)
         code = '\n'.join('    ' + line for line in code.split('\n'))
+
+        # as we need colossalai.utils.checkpoint, we need to import colossalai
+        # in forward function
+        # TODO: Remove inline import
         fn_code = f"""
 {wrap_stmts}
 
 def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
+    import colossalai
 {code}"""
         return PythonCode(fn_code, globals_)