[fx] Add offload codegen (#1598)

* [fx] add input activation offload to codegen

* [fx] modify unit test

* [fx] remove two skips in torch11

* [fx] use all_input_nodes instead of _input_nodes
This commit is contained in:
Boyuan Yao
2022-09-14 15:49:06 +08:00
committed by GitHub
parent c8e9b2ad78
commit a7cda6f57d
4 changed files with 264 additions and 20 deletions

View File

@@ -17,6 +17,38 @@ else:
__all__ = ['python_code_with_activation_checkpoint']
def _gen_saved_tensors_hooks():
"""
Generate saved tensors hooks
"""
pack_hook = """def pack_hook(self, x):
if getattr(x, "offload", None):
return (x.device, x.cpu())
else:
return x
"""
unpack_hook = """def unpack_hook(self, packed):
if isinstance(packed, tuple):
device, tensor = packed
return tensor.to(device)
else:
return packed
"""
return pack_hook, unpack_hook
def _gen_save_tensors_hooks_context():
"""
Generate save tensors hooks context
"""
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):\n"
return context
def _find_input_and_output_nodes(nodes: List[Node]):
"""
Find the input and output node names which are not found in the given list of nodes.
@@ -211,7 +243,7 @@ def emit_ckpt_func(body,
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n')
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
@@ -251,7 +283,7 @@ def emit_ckpt_func(body,
delete_unused_value_func(node, ckpt_func)
node_idx += 1
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n')
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
ckpt_func += ckpt_func_buffer
activation_offload = getattr(node_list[0], "activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
@@ -266,7 +298,7 @@ def emit_ckpt_func(body,
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n')
ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
activation_offload = getattr(node_list[0], "activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
if in_ckpt:
@@ -292,6 +324,9 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
node_list = list(nodes)
# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
is_hook_inserted = False
node_idx = 0
while 1:
# break if we finish the processing all the nodes
@@ -307,8 +342,27 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# process node in forward function
else:
node = node_list[node_idx]
emit_node_func(node, body)
delete_unused_value_func(node, body)
# if a node is outside of checkpoint region and want to offload
# it's input activation, we will use torch.saved_tensors_hooks
# to complete the offload process.
if getattr(node, "activation_offload", False):
if not is_hook_inserted:
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
for par in node.all_input_nodes:
# annotate the input tensor for pack hook
body.append(f"setattr({repr(par)}, 'offload', True)\n")
body.append(_gen_save_tensors_hooks_context())
emit_node_func(node, body)
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body)
else:
emit_node_func(node, body)
delete_unused_value_func(node, body)
node_idx += 1
@@ -323,6 +377,10 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
node_list = list(nodes)
# use this variable to avoid inserting hook functions
# to ckpt_func repeatedly
is_hook_inserted = False
# find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions):
ckpt_node_list = node_list[start:end + 1]
@@ -348,8 +406,26 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
ckpt_func[-1] = ' ' + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
else:
emit_node_func(node, body)
delete_unused_value_func(node, body)
# if a node is outside of checkpoint region wants to offload
# it's input activation, we will use torch.saved_tensors_hooks
# to complete the offload process.
if getattr(node, "activation_offload", False):
if not is_hook_inserted:
pack_hook, unpack_hook = _gen_saved_tensors_hooks()
ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
for par in node.all_input_nodes:
# annotate the input tensor for pack hook
body.append(f"setattr({repr(par)}, 'offload', True)\n")
body.append(_gen_save_tensors_hooks_context())
emit_node_func(node, body)
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body)
else:
emit_node_func(node, body)
delete_unused_value_func(node, body)
if idx in end_idx:
# if this is the last node of the ckpt region
@@ -587,10 +663,13 @@ if CODEGEN_AVAILABLE:
# Modified for activation checkpointing
ckpt_func = []
if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@@ -612,7 +691,6 @@ if CODEGEN_AVAILABLE:
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
# TODO: Remove inline import
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = ''.join(ckpt_func) + prologue
prologue = prologue
@@ -788,10 +866,13 @@ else:
# Modified for activation checkpointing
ckpt_func = []
if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@@ -827,7 +908,6 @@ else:
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
# TODO: Remove inline import
fn_code = f"""
{wrap_stmts}