[fx] Fix ckpt functions' definitions in forward (#1476)

* [fx] fix defining ckpt functions inside forward

* [fx] Modify activation checkpoint codegen and add ColoGraphModule

* [fx] some modification

* some modifications

* some modifications

* some modifications

* some modifications

* some code modifications
This commit is contained in:
Boyuan Yao
2022-08-22 16:59:54 +08:00
committed by GitHub
parent bb5f5289e0
commit 1f2e547f7a
4 changed files with 217 additions and 39 deletions

View File

@@ -1,12 +1,13 @@
import colossalai
import torch
from typing import List, Callable, Any, Tuple, Dict
try:
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
CODEGEN_AVAILABLE = False
@@ -89,7 +90,7 @@ def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
"""
Generate the checkpoint function definition
"""
return f"def checkpoint_{label}({', '.join(free_vars)}):"
return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
def _gen_ckpt_output(output_vars: List[str]) -> str:
@@ -105,10 +106,10 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
# find the activation checkpoint regions
ckpt_regions = _find_ckpt_regions(nodes)
start_idx = [item[0] for item in ckpt_regions]
@@ -133,27 +134,27 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
if idx in start_idx:
label = start_idx.index(idx)
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
body.append(f'{ckpt_fn_def}\n')
ckpt_func.append(f'{ckpt_fn_def}\n')
within_ckpt_region = True
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
emit_node_func(node)
# add indentation to the emmited node
# NOTE: currently we separate body and ckpt_func definition
if within_ckpt_region:
body[-1] = ' ' + body[-1]
# delete unused values
delete_unused_value_func(node)
emit_node_func(node, ckpt_func)
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
else:
emit_node_func(node, body)
delete_unused_value_func(node, body)
if idx in end_idx:
# if this is the last node of the ckpt region
# generate return statement
label = end_idx.index(idx)
return_statement = _gen_ckpt_output(output_vars[label])
return_statement = f' {return_statement}\n'
body.append(return_statement)
return_statement = f' {return_statement}\n\n'
ckpt_func.append(return_statement)
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
@@ -221,6 +222,9 @@ if CODEGEN_AVAILABLE:
globals_[global_name] = obj
return global_name
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
# Pre-fill the globals table with registered builtins.
for name, (_, obj) in _custom_builtins.items():
add_global(name, obj)
@@ -287,7 +291,8 @@ if CODEGEN_AVAILABLE:
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
def delete_unused_values(user: Node):
# NOTE: we add a variable to distinguish body and ckpt_func
def delete_unused_values(user: Node, body):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
@@ -305,7 +310,8 @@ if CODEGEN_AVAILABLE:
else:
body.append('\n')
def emit_node(node: Node):
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
assert isinstance(node.target, str)
@@ -371,7 +377,8 @@ if CODEGEN_AVAILABLE:
raise NotImplementedError(f'node: {node.op} {node.target}')
# Modified for activation checkpointing
emit_code_with_activation_checkpoint(body, nodes, emit_node, delete_unused_values)
ckpt_func = []
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@@ -395,7 +402,8 @@ if CODEGEN_AVAILABLE:
# in forward function
# TODO: Remove inline import
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = prologue + "\n import colossalai"
prologue = ''.join(ckpt_func) + prologue
prologue = prologue
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
@@ -444,6 +452,9 @@ else:
globals_[global_name] = obj
return global_name
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
# Pre-fill the globals table with registered builtins.
for name, (_, obj) in _custom_builtins.items():
add_global(name, obj)
@@ -484,7 +495,8 @@ else:
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
def delete_unused_values(user: Node):
# NOTE: we add a variable to distinguish body and ckpt_func
def delete_unused_values(user: Node, body):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
@@ -502,7 +514,8 @@ else:
else:
body.append('\n')
def emit_node(node: Node):
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
assert isinstance(node.target, str)
@@ -562,7 +575,8 @@ else:
raise NotImplementedError(f'node: {node.op} {node.target}')
# Modified for activation checkpointing
emit_code_with_activation_checkpoint(body, self.nodes, emit_node, delete_unused_values)
ckpt_func = []
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@@ -587,6 +601,8 @@ else:
else:
wrap_stmts = ''
ckpt_func = ''.join(ckpt_func)
# If the original function didn't have self as its first argument, we
# would have added it.
if len(orig_args) == 0 or orig_args[0] != 'self':
@@ -600,7 +616,7 @@ else:
fn_code = f"""
{wrap_stmts}
{ckpt_func}
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
import colossalai
{code}"""
return PythonCode(fn_code, globals_)