mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[fx] Fix ckpt functions' definitions in forward (#1476)
* [fx] fix defining ckpt functions inside forward
* [fx] Modify activation checkpoint codegen and add ColoGraphModule
* [fx] some modification
* some modifications
* some modifications
* some modifications
* some modifications
* some code modifications
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import colossalai
|
||||
import torch
|
||||
from typing import List, Callable, Any, Tuple, Dict
|
||||
|
||||
try:
|
||||
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
|
||||
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods
|
||||
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
|
||||
CODEGEN_AVAILABLE = True
|
||||
except:
|
||||
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args
|
||||
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
|
||||
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
|
||||
CODEGEN_AVAILABLE = False
|
||||
|
||||
@@ -89,7 +90,7 @@ def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
|
||||
"""
|
||||
Generate the checkpoint function definition
|
||||
"""
|
||||
return f"def checkpoint_{label}({', '.join(free_vars)}):"
|
||||
return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
|
||||
|
||||
|
||||
def _gen_ckpt_output(output_vars: List[str]) -> str:
|
||||
@@ -105,10 +106,10 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
|
||||
"""
|
||||
outputs = ', '.join(output_vars)
|
||||
inputs = ', '.join(input_vars)
|
||||
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
|
||||
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
|
||||
|
||||
|
||||
def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
|
||||
def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
|
||||
# find the activation checkpoint regions
|
||||
ckpt_regions = _find_ckpt_regions(nodes)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
@@ -133,27 +134,27 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
|
||||
if idx in start_idx:
|
||||
label = start_idx.index(idx)
|
||||
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
|
||||
body.append(f'{ckpt_fn_def}\n')
|
||||
ckpt_func.append(f'{ckpt_fn_def}\n')
|
||||
within_ckpt_region = True
|
||||
|
||||
# NOTE: emit_node does not emit a string with newline. It depends
|
||||
# on delete_unused_values to append one
|
||||
emit_node_func(node)
|
||||
|
||||
# add indentation to the emitted node
|
||||
# NOTE: currently we separate body and ckpt_func definition
|
||||
if within_ckpt_region:
|
||||
body[-1] = ' ' + body[-1]
|
||||
|
||||
# delete unused values
|
||||
delete_unused_value_func(node)
|
||||
emit_node_func(node, ckpt_func)
|
||||
ckpt_func[-1] = ' ' + ckpt_func[-1]
|
||||
delete_unused_value_func(node, ckpt_func)
|
||||
else:
|
||||
emit_node_func(node, body)
|
||||
delete_unused_value_func(node, body)
|
||||
|
||||
if idx in end_idx:
|
||||
# if this is the last node of the ckpt region
|
||||
# generate return statement
|
||||
label = end_idx.index(idx)
|
||||
return_statement = _gen_ckpt_output(output_vars[label])
|
||||
return_statement = f' {return_statement}\n'
|
||||
body.append(return_statement)
|
||||
return_statement = f' {return_statement}\n\n'
|
||||
ckpt_func.append(return_statement)
|
||||
|
||||
# we need to check if the checkpoint need to offload the input
|
||||
start_node_idx = start_idx[label]
|
||||
@@ -221,6 +222,9 @@ if CODEGEN_AVAILABLE:
|
||||
globals_[global_name] = obj
|
||||
return global_name
|
||||
|
||||
# set _custom_builtins here so that we needn't import colossalai in forward
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
|
||||
|
||||
# Pre-fill the globals table with registered builtins.
|
||||
for name, (_, obj) in _custom_builtins.items():
|
||||
add_global(name, obj)
|
||||
@@ -287,7 +291,8 @@ if CODEGEN_AVAILABLE:
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
def delete_unused_values(user: Node):
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def delete_unused_values(user: Node, body):
|
||||
"""
|
||||
Delete values after their last use. This ensures that values that are
|
||||
not used in the remainder of the code are freed and the memory usage
|
||||
@@ -305,7 +310,8 @@ if CODEGEN_AVAILABLE:
|
||||
else:
|
||||
body.append('\n')
|
||||
|
||||
def emit_node(node: Node):
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def emit_node(node: Node, body):
|
||||
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
|
||||
if node.op == 'placeholder':
|
||||
assert isinstance(node.target, str)
|
||||
@@ -371,7 +377,8 @@ if CODEGEN_AVAILABLE:
|
||||
raise NotImplementedError(f'node: {node.op} {node.target}')
|
||||
|
||||
# Modified for activation checkpointing
|
||||
emit_code_with_activation_checkpoint(body, nodes, emit_node, delete_unused_values)
|
||||
ckpt_func = []
|
||||
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
@@ -395,7 +402,8 @@ if CODEGEN_AVAILABLE:
|
||||
# in forward function
|
||||
# TODO: Remove inline import
|
||||
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
|
||||
prologue = prologue + "\n import colossalai"
|
||||
prologue = ''.join(ckpt_func) + prologue
|
||||
prologue = prologue
|
||||
|
||||
code = ''.join(body)
|
||||
code = '\n'.join(' ' + line for line in code.split('\n'))
|
||||
@@ -444,6 +452,9 @@ else:
|
||||
globals_[global_name] = obj
|
||||
return global_name
|
||||
|
||||
# set _custom_builtins here so that we needn't import colossalai in forward
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
|
||||
|
||||
# Pre-fill the globals table with registered builtins.
|
||||
for name, (_, obj) in _custom_builtins.items():
|
||||
add_global(name, obj)
|
||||
@@ -484,7 +495,8 @@ else:
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
def delete_unused_values(user: Node):
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def delete_unused_values(user: Node, body):
|
||||
"""
|
||||
Delete values after their last use. This ensures that values that are
|
||||
not used in the remainder of the code are freed and the memory usage
|
||||
@@ -502,7 +514,8 @@ else:
|
||||
else:
|
||||
body.append('\n')
|
||||
|
||||
def emit_node(node: Node):
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def emit_node(node: Node, body):
|
||||
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
|
||||
if node.op == 'placeholder':
|
||||
assert isinstance(node.target, str)
|
||||
@@ -562,7 +575,8 @@ else:
|
||||
raise NotImplementedError(f'node: {node.op} {node.target}')
|
||||
|
||||
# Modified for activation checkpointing
|
||||
emit_code_with_activation_checkpoint(body, self.nodes, emit_node, delete_unused_values)
|
||||
ckpt_func = []
|
||||
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
@@ -587,6 +601,8 @@ else:
|
||||
else:
|
||||
wrap_stmts = ''
|
||||
|
||||
ckpt_func = ''.join(ckpt_func)
|
||||
|
||||
# If the original function didn't have self as its first argument, we
|
||||
# would have added it.
|
||||
if len(orig_args) == 0 or orig_args[0] != 'self':
|
||||
@@ -600,7 +616,7 @@ else:
|
||||
fn_code = f"""
|
||||
{wrap_stmts}
|
||||
|
||||
{ckpt_func}
|
||||
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
|
||||
import colossalai
|
||||
{code}"""
|
||||
return PythonCode(fn_code, globals_)
|
||||
|
Reference in New Issue
Block a user