align evoformer

This commit is contained in:
oahzxl
2022-11-02 15:49:25 +08:00
parent 86f2a31474
commit 820ea4d056
6 changed files with 66 additions and 191 deletions

View File

@@ -1,5 +1,6 @@
import colossalai
import torch
import copy
from typing import List, Callable, Any, Tuple, Dict, Iterable
try:
@@ -17,74 +18,18 @@ else:
__all__ = ['python_code_with_activation_checkpoint']
def _gen_saved_tensors_hooks():
"""
Generate saved tensors hooks
"""
pack_hook = """def pack_hook_input(self, x):
if getattr(x, "offload", False):
return (x.device, x.cpu())
else:
return x
def pack_hook_no_input(self, x):
if getattr(x, "offload", True):
return (x.device, x.cpu())
else:
return x
"""
unpack_hook = """def unpack_hook(self, packed):
if isinstance(packed, tuple):
device, tensor = packed
return tensor.to(device)
else:
return packed
"""
return pack_hook, unpack_hook
def _gen_loop_5(to_keep):
context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
context += " chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
def _gen_loop_start(to_keep, chunk_size=2):
context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
context += " chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
return context
def _gen_loop_5_final(final_name, to_keep):
def _gen_loop_end(final_name, to_keep):
context = " chunk_result.append(" + final_name + ")\n"
context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
context += final_name + " = chunk_result; chunk_result = None\n"
return context
def _gen_save_tensors_hooks_context(offload_input=True) -> str:
"""Generate customized saved_tensors_hooks
Args:
offload_input (bool, optional): whether we need offload input, if offload_input=False,
we will use self.pack_hook_no_input instead. Defaults to True.
Returns:
str: generated context
"""
if offload_input:
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
else:
context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
return context
def _gen_save_on_cpu_context():
"""
Generate save on cpu context
"""
context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
return context
def _find_input_and_output_nodes(nodes: List[Node]):
"""
@@ -112,49 +57,6 @@ def _find_input_and_output_nodes(nodes: List[Node]):
return input_nodes, output_nodes
def _find_ckpt_regions(nodes: List[Node]):
"""
Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_nodes = []
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
act_ckpt_label = node.activation_checkpoint
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
if current_region is None:
current_region = act_ckpt_label
start = idx
# if activation checkpoint has changed
# we restart the tracking
# e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
if act_ckpt_label != current_region:
assert start != -1
ckpt_regions.append((start, idx - 1))
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
assert start != -1 and end != -1
ckpt_regions.append((start, end))
start = end = -1
current_region = None
else:
pass
return ckpt_regions
def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
@@ -400,12 +302,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
"""
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
# find the offload regions
chunk_regions, chunk_labels = _find_offload_regions(nodes)
chunk_regions = [(1, 4)]
chunk_starts = [item[0] for item in chunk_regions]
chunk_ends = [item[1] for item in chunk_regions]
chunk_inputs = []
@@ -424,7 +323,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
# this flag is to prevent repeated insert of save tensors
# hooks definition in ckpt_func
node_idx = 0
to_keep = []
chunk_var = []
while node_idx < len(node_list):
# break if we finish the processing all the nodes
if node_idx >= len(node_list):
@@ -435,28 +334,30 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
node = node_list[node_idx]
if node_idx in chunk_starts:
# save chunk input var, dont delete it
to_keep.extend(node.args[0].name)
within_chunk_region = True
# add for loop
body.append(_gen_loop_5(to_keep[0]))
# change first node's input to new chunked var
node_args = list(node.args)
node_args[0] = 'chunk_tensor'
# save chunk input var, dont delete it
chunk_var.append(node.args[0].name)
# add for loop
body.append(_gen_loop_start(chunk_var[0]))
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
if node_idx in chunk_starts:
body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)')
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body, to_keep)
delete_unused_value_func(node, body, chunk_var)
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
delete_unused_value_func(node, body, to_keep)
delete_unused_value_func(node, body, chunk_var)
if node_idx in chunk_ends:
body.append(_gen_loop_5_final(node.name, to_keep))
to_keep = []
body.append(_gen_loop_end(node.name, chunk_var))
chunk_var = []
within_chunk_region = False
node_idx += 1
@@ -580,9 +481,7 @@ if CODEGEN_AVAILABLE:
body.append('\n')
return
nodes_to_delete = user_to_last_uses.get(user, [])
for n in nodes_to_delete:
if n.name in to_keep:
nodes_to_delete.remove(n)
nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')