[fx] Use colossalai checkpoint and add offload recognition in codegen (#1439)

* [fx] Use colossalai.utils.checkpoint to replace torch.utils.checkpoint for offload activation and add offload annotation recognition in codegen

* [fx] Use colossalai.utils.checkpoint to replace torch.utils.checkpoint for offload activation and add offload annotation recognition in codegen

* Modify the test and add a TODO in codegen

* [fx] Modification of colossal ckpt usage

* [fx] add gpc.destroy() to test_codegen
This commit is contained in:
Boyuan Yao
2022-08-12 12:23:30 +08:00
committed by GitHub
parent e9460b45c8
commit 5774fe0270
2 changed files with 63 additions and 7 deletions

View File

@@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
return f"return {', '.join(output_vars)}"
def _gen_ckpt_usage(label, input_vars, output_vars):
def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
"""
Generate the checkpoint function call code text
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'
def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
@@ -155,8 +155,15 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
return_statement = f' {return_statement}\n'
body.append(return_statement)
# we need to check if the checkpoint needs to offload the input
start_node_idx = start_idx[label]
if hasattr(node_list[start_node_idx], 'activation_offload'):
activation_offload = node_list[start_node_idx].activation_offload
else:
activation_offload = False
# generate checkpoint function call in a new line
usage = _gen_ckpt_usage(label, input_vars[label], output_vars[label])
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
usage += '\n'
body.append(usage)
within_ckpt_region = False
@@ -368,7 +375,11 @@ if codegen_available:
for name, value in self.additional_globals():
add_global(name, value)
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
# TODO: Remove inline import
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = prologue + "\n import colossalai"
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
@@ -566,9 +577,14 @@ else:
orig_args.insert(0, 'self')
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
# TODO: Remove inline import
fn_code = f"""
{wrap_stmts}
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
import colossalai
{code}"""
return PythonCode(fn_code, globals_)