mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[autochunk] support evoformer tracer (#2485)
Support the full Evoformer tracer, which is a main module of AlphaFold. Previously we only supported a simplified version of it. 1. Support some of Evoformer's ops in fx 2. Support Evoformer test 3. Add repos for test code
This commit is contained in:
@@ -10,6 +10,7 @@ from .utils import (
|
||||
|
||||
|
||||
class TraceFlow(object):
|
||||
|
||||
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||
self.trace_indice = trace_indice
|
||||
|
||||
@@ -28,9 +29,7 @@ class TraceFlow(object):
|
||||
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
|
||||
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||
sorted_source = sorted(
|
||||
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
|
||||
)
|
||||
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
|
||||
for node_idx, node_dim in sorted_source:
|
||||
if node_idx == start_node_idx and start_dim in node_dim:
|
||||
return True
|
||||
@@ -70,10 +69,8 @@ class TraceFlow(object):
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(get_node_shape(node))):
|
||||
if (
|
||||
input_node_idx in node_trace_source[node_dim]
|
||||
and input_dim[0] in node_trace_source[node_dim][input_node_idx]
|
||||
):
|
||||
if (input_node_idx in node_trace_source[node_dim]
|
||||
and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
|
||||
return node_dim
|
||||
return None
|
||||
|
||||
@@ -81,15 +78,11 @@ class TraceFlow(object):
|
||||
input_dim_after_node = {}
|
||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||
inherit_dim = self._find_inherit_dim(
|
||||
input_node, v, self.trace_indice.node_list[k]
|
||||
)
|
||||
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
|
||||
if inherit_dim:
|
||||
input_dim_after_node[k] = inherit_dim
|
||||
|
||||
for node in self.trace_indice.node_list[
|
||||
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
||||
]:
|
||||
for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
|
||||
if is_non_compute_node_except_placeholder(node):
|
||||
continue
|
||||
count = 0
|
||||
@@ -159,9 +152,7 @@ class TraceFlow(object):
|
||||
if arg_node in all_node_info:
|
||||
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
|
||||
return False
|
||||
all_node_info[arg_node]["fix_dim"] = list(
|
||||
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
|
||||
)
|
||||
all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
|
||||
# else add it to list
|
||||
else:
|
||||
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
|
||||
@@ -170,9 +161,7 @@ class TraceFlow(object):
|
||||
return True
|
||||
|
||||
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||
cur_node_list = [
|
||||
self.trace_indice.node_list[end_idx]
|
||||
] # start from the last node
|
||||
cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node
|
||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||
|
||||
while len(cur_node_list) > 0:
|
||||
@@ -183,12 +172,8 @@ class TraceFlow(object):
|
||||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||
if cur_node_chunk_dim:
|
||||
cur_node_compute = self.trace_indice._find_compute_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
cur_node_source = self.trace_indice._find_source_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
|
||||
cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
|
||||
else:
|
||||
cur_node_compute = cur_node_source = None
|
||||
|
||||
@@ -215,15 +200,9 @@ class TraceFlow(object):
|
||||
return None
|
||||
|
||||
if len(arg_list) == 2:
|
||||
if any(i in cur_node.name for i in ["add", "mul"]):
|
||||
if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
|
||||
for arg in arg_list:
|
||||
if not (
|
||||
start_idx
|
||||
<= find_idx_by_name(
|
||||
arg.name, self.trace_indice.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
|
||||
continue
|
||||
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
|
||||
arg_fix_dim = all_node_info[arg]["fix_dim"]
|
||||
@@ -249,9 +228,7 @@ class TraceFlow(object):
|
||||
remove_inputs = []
|
||||
for input_node in inputs:
|
||||
input_dict = {}
|
||||
input_node_idx = find_idx_by_name(
|
||||
input_node.name, self.trace_indice.node_list
|
||||
)
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
for user in input_node.users.keys():
|
||||
if is_non_compute_node(user):
|
||||
continue
|
||||
@@ -259,9 +236,7 @@ class TraceFlow(object):
|
||||
if start_idx <= user_idx <= end_idx:
|
||||
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||
if chunk_dim is not None:
|
||||
user_source = self.trace_indice._find_source_trace_from_node(
|
||||
user
|
||||
)[chunk_dim]
|
||||
user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
|
||||
if input_node_idx in user_source:
|
||||
input_dict[user_idx] = user_source[input_node_idx]
|
||||
else:
|
||||
@@ -284,7 +259,7 @@ class TraceFlow(object):
|
||||
maybe_prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
|
||||
reverse=True,
|
||||
) # from last node to first node
|
||||
) # from last node to first node
|
||||
prepose_nodes = []
|
||||
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
|
||||
while len(maybe_prepose_nodes) > 0:
|
||||
@@ -305,13 +280,8 @@ class TraceFlow(object):
|
||||
if type(cur_prepose_node_arg) != type(cur_prepose_node):
|
||||
continue
|
||||
# out of loop
|
||||
if not (
|
||||
start_idx
|
||||
<= find_idx_by_name(
|
||||
cur_prepose_node_arg.name, self.trace_indice.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
|
||||
end_idx):
|
||||
continue
|
||||
# compute op in loop
|
||||
elif cur_prepose_node_arg in all_node_info:
|
||||
@@ -335,15 +305,13 @@ class TraceFlow(object):
|
||||
if n in maybe_prepose_nodes:
|
||||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
|
||||
)
|
||||
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
|
||||
|
||||
return prepose_nodes
|
||||
|
||||
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
||||
# we need to log input nodes to avoid deleting them in the loop
|
||||
chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||
chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
|
||||
# also need to get some prepose node's arg out of non_chunk_inputs
|
||||
for n in chunk_info["args"]["prepose_nodes"]:
|
||||
chunk_node_list.remove(n)
|
||||
@@ -354,9 +322,7 @@ class TraceFlow(object):
|
||||
return chunk_info
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||
self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
|
||||
# only a single output is supported
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
@@ -367,9 +333,7 @@ class TraceFlow(object):
|
||||
return None
|
||||
|
||||
# get input nodes' chunk dim
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(
|
||||
inputs, start_idx, end_idx, all_node_info
|
||||
)
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
@@ -385,9 +349,7 @@ class TraceFlow(object):
|
||||
}
|
||||
|
||||
# move useless nodes ahead of loop
|
||||
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
|
||||
all_node_info, start_idx, end_idx
|
||||
)
|
||||
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
|
||||
|
||||
# find non chunk inputs
|
||||
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
|
||||
@@ -400,10 +362,8 @@ class TraceFlow(object):
|
||||
def _reassgin_reshape_size(self, chunk_info):
|
||||
chunk_region = chunk_info["region"]
|
||||
reshape_size = {}
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[
|
||||
chunk_info["outputs_dim"]
|
||||
]
|
||||
for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
|
||||
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
|
||||
if any(i in node.name for i in ["reshape", "view"]):
|
||||
reshape_args = node.args[1:]
|
||||
reshape_log = self.trace_indice.indice_view_list[node]
|
||||
@@ -413,8 +373,6 @@ class TraceFlow(object):
|
||||
if reshape_arg_dim in reshape_log["dim_to"]:
|
||||
continue
|
||||
if reshape_arg_dim == chunk_dim:
|
||||
reshape_size[node.name][reshape_arg.name] = (
|
||||
"min(chunk_size, %d - chunk_idx)" % chunk_shape
|
||||
)
|
||||
reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
|
||||
chunk_info["reshape_size"] = reshape_size
|
||||
return chunk_info
|
||||
|
Reference in New Issue
Block a user