align evoformer

Rename the hard-coded chunk-loop generators _gen_loop_5 / _gen_loop_5_final to _gen_loop_start / _gen_loop_end, parameterize them by chunk_size, and stride over dim 0 of the chunked input instead of a fixed range(4). emit_code_with_chunk now tracks the chunk input in chunk_var and rewrites the chunk-start node's emitted line to read from chunk_tensor. The unused saved-tensors-hooks helpers and _find_ckpt_regions are removed, and the _find_nested_ckpt_regions / _find_offload_regions calls are replaced by a hard-coded chunk region [(1, 4)]. delete_unused_values now filters kept names with a list comprehension instead of mutating the list while iterating over it.

Changed file: chunk_codegen.py
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -1,5 +1,6 @@
 import colossalai
 import torch
+import copy
 from typing import List, Callable, Any, Tuple, Dict, Iterable
 
 try:
@@ -17,74 +18,18 @@ else:
 __all__ = ['python_code_with_activation_checkpoint']
 
 
-def _gen_saved_tensors_hooks():
-    """
-    Generate saved tensors hooks
-    """
-
-    pack_hook = """def pack_hook_input(self, x):
-    if getattr(x, "offload", False):
-        return (x.device, x.cpu())
-    else:
-        return x
-
-def pack_hook_no_input(self, x):
-    if getattr(x, "offload", True):
-        return (x.device, x.cpu())
-    else:
-        return x
-"""
-
-    unpack_hook = """def unpack_hook(self, packed):
-    if isinstance(packed, tuple):
-        device, tensor = packed
-        return tensor.to(device)
-    else:
-        return packed
-"""
-
-    return pack_hook, unpack_hook
-
-
-def _gen_loop_5(to_keep):
-    context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
-    context += " chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
+def _gen_loop_start(to_keep, chunk_size=2):
+    context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
+    context += " chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
     return context
 
 
-def _gen_loop_5_final(final_name, to_keep):
+def _gen_loop_end(final_name, to_keep):
     context = " chunk_result.append(" + final_name + ")\n"
     context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
     context += final_name + " = chunk_result; chunk_result = None\n"
     return context
 
 
-def _gen_save_tensors_hooks_context(offload_input=True) -> str:
-    """Generate customized saved_tensors_hooks
-
-    Args:
-        offload_input (bool, optional): whether we need offload input, if offload_input=False,
-            we will use self.pack_hook_no_input instead. Defaults to True.
-
-    Returns:
-        str: generated context
-    """
-
-    if offload_input:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
-    else:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
-    return context
-
-
-def _gen_save_on_cpu_context():
-    """
-    Generate save on cpu context
-    """
-
-    context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
-    return context
-
-
 def _find_input_and_output_nodes(nodes: List[Node]):
     """
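To make the new helpers concrete, here is a sketch of the source they emit, assuming a chunked input named x and a final node named out. The single-character name is chosen deliberately: the call site below passes chunk_var[0] (a name string) while _gen_loop_start formats to_keep[0] into the loop header, and the two spellings only coincide for one-character names.

    print(_gen_loop_start("x", chunk_size=2), end="")
    # ... emit_code_with_chunk appends the re-indented node code here ...
    print(_gen_loop_end("out", ["x"]), end="")

which prints roughly:

    chunk_result = []; chunk_size = 2
    for gen_loop_idx in range(0, x.shape[0], chunk_size):
     chunk_tensor = x[gen_loop_idx:gen_loop_idx + chunk_size, :]
     chunk_result.append(out)
    chunk_result = torch.cat(chunk_result, dim=0); x = None
    out = chunk_result; chunk_result = None

The loop body is indented by a single space, matching the ' ' prefix that emit_code_with_chunk prepends to each emitted line inside the chunk region.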
@@ -112,49 +57,6 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes
 
 
-def _find_ckpt_regions(nodes: List[Node]):
-    """
-    Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
-    of tuples, each tuple is in the form of (start_index, end_index).
-    """
-    ckpt_nodes = []
-    ckpt_regions = []
-    start = -1
-    end = -1
-    current_region = None
-
-    for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            act_ckpt_label = node.activation_checkpoint
-
-            # this activation checkpoint label is not set yet
-            # meaning this is the first node of the activation ckpt region
-            if current_region is None:
-                current_region = act_ckpt_label
-                start = idx
-
-            # if activation checkpoint has changed
-            # we restart the tracking
-            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
-            if act_ckpt_label != current_region:
-                assert start != -1
-                ckpt_regions.append((start, idx - 1))
-                current_region = act_ckpt_label
-                start = idx
-                end = -1
-        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
-            # used to check the case below
-            # node ckpt states = [ckpt, ckpt, non-ckpt]
-            end = idx - 1
-            assert start != -1 and end != -1
-            ckpt_regions.append((start, end))
-            start = end = -1
-            current_region = None
-        else:
-            pass
-    return ckpt_regions
-
-
 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions
     In pofo algorithm, during annotation, we will annotate the offload region with the
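For intuition about what this hunk drops: _find_ckpt_regions folded per-node activation_checkpoint labels into (start_index, end_index) tuples, closing a region when the label changed or an unlabeled node appeared. A minimal sketch of that labeling logic, using hypothetical stand-in objects rather than real torch.fx nodes:

    from types import SimpleNamespace

    # three nodes labeled ckpt 0, two labeled ckpt 1, then one unlabeled node
    nodes = [SimpleNamespace(activation_checkpoint=0)] * 3 \
          + [SimpleNamespace(activation_checkpoint=1)] * 2 \
          + [SimpleNamespace()]

    # walking the labels as _find_ckpt_regions did would yield [(0, 2), (3, 4)]:
    # the label change at index 3 closes (0, 2), and the unlabeled node at
    # index 5 closes (3, 4)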
@@ -400,12 +302,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
         emit_node_func: function to emit node
         delete_unused_value_func: function to remove the unused value
     """
-    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
-    start_idx = [item[0] for item in ckpt_regions]
-    end_idx = [item[1] for item in ckpt_regions]
 
     # find the offload regions
-    chunk_regions, chunk_labels = _find_offload_regions(nodes)
+    chunk_regions = [(1, 4)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
@@ -424,7 +323,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     node_idx = 0
-    to_keep = []
+    chunk_var = []
     while node_idx < len(node_list):
         # break if we finish the processing all the nodes
         if node_idx >= len(node_list):
@@ -435,28 +334,30 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
         node = node_list[node_idx]
 
         if node_idx in chunk_starts:
-            # save chunk input var, dont delete it
-            to_keep.extend(node.args[0].name)
             within_chunk_region = True
-            # add for loop
-            body.append(_gen_loop_5(to_keep[0]))
-            # change first node's input to new chunked var
-            node_args = list(node.args)
-            node_args[0] = 'chunk_tensor'
+
+            # save chunk input var, dont delete it
+            chunk_var.append(node.args[0].name)
+
+            # add for loop
+            body.append(_gen_loop_start(chunk_var[0]))
 
         if within_chunk_region:
             emit_node_func(node, body)
+            # replace input var with chunk var
+            if node_idx in chunk_starts:
+                body[-1] = body[-1].replace("("+ chunk_var[0] +")", '(chunk_tensor)')
             body[-1] = ' ' + body[-1]
-            delete_unused_value_func(node, body, to_keep)
+            delete_unused_value_func(node, body, chunk_var)
 
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
-                delete_unused_value_func(node, body, to_keep)
+                delete_unused_value_func(node, body, chunk_var)
 
         if node_idx in chunk_ends:
-            body.append(_gen_loop_5_final(node.name, to_keep))
-            to_keep = []
+            body.append(_gen_loop_end(node.name, chunk_var))
+            chunk_var = []
             within_chunk_region = False
 
         node_idx += 1
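The rewrite at a chunk-start node works on the emitted source line itself. A small illustration of the two string edits, with a hypothetical line and variable names (the real code builds the "(...)" pattern from chunk_var[0]):

    # what emit_node_func might append for the first node in the chunk region
    line = "hidden = self.linear(x)"

    # chunk-start rewrite: read from the per-chunk slice instead of the full input
    line = line.replace("(x)", "(chunk_tensor)")
    # -> "hidden = self.linear(chunk_tensor)"

    # every line in the region is then indented into the generated loop body
    line = ' ' + line
    # -> " hidden = self.linear(chunk_tensor)"

Note the textual replace only fires when the chunked variable is the node's sole parenthesized argument; a multi-argument call site would need a more careful rewrite.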
@@ -580,9 +481,7 @@ if CODEGEN_AVAILABLE:
                 body.append('\n')
                 return
             nodes_to_delete = user_to_last_uses.get(user, [])
-            for n in nodes_to_delete:
-                if n.name in to_keep:
-                    nodes_to_delete.remove(n)
+            nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
             if len(nodes_to_delete):
                 to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                 body.append(f'; {to_delete_str}\n')
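The final hunk fixes a classic pitfall: calling remove on a list while iterating over it shifts the remaining elements left, so the iterator skips the element right after each removal. A standalone demonstration with plain strings (hypothetical names):

    to_keep = ["a", "b"]
    nodes_to_delete = ["a", "b", "c"]

    # old pattern: removing "a" shifts "b" into the slot the iterator
    # just consumed, so "b" is never examined
    buggy = list(nodes_to_delete)
    for n in buggy:
        if n in to_keep:
            buggy.remove(n)
    print(buggy)    # ['b', 'c'] -- "b" escaped the filter

    # new pattern: build a fresh list, nothing is skipped
    fixed = [n for n in nodes_to_delete if n not in to_keep]
    print(fixed)    # ['c']

The old loop only happened to work when kept nodes were not adjacent in the list; the list comprehension is correct unconditionally.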