[autochunk] support evoformer tracer (#2485)

Support the full Evoformer tracer; Evoformer is a main module of AlphaFold. Previously we only supported a simplified version of it.
1. support some of Evoformer's ops in fx
2. support evoformer test
3. add repos for test code
This commit is contained in:
oahzxl
2023-01-16 19:25:05 +08:00
committed by GitHub
parent 67e1912b59
commit 4953b4ace1
25 changed files with 339 additions and 3215 deletions

View File

@@ -10,6 +10,7 @@ from .utils import (
class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice) -> None:
self.trace_indice = trace_indice
@@ -28,9 +29,7 @@ class TraceFlow(object):
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
end_node_trace_source = end_node_trace["source"][end_dim]
sorted_source = sorted(
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
)
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
for node_idx, node_dim in sorted_source:
if node_idx == start_node_idx and start_dim in node_dim:
return True
@@ -70,10 +69,8 @@ class TraceFlow(object):
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
for node_dim in range(len(get_node_shape(node))):
if (
input_node_idx in node_trace_source[node_dim]
and input_dim[0] in node_trace_source[node_dim][input_node_idx]
):
if (input_node_idx in node_trace_source[node_dim]
and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
return node_dim
return None
@@ -81,15 +78,11 @@ class TraceFlow(object):
input_dim_after_node = {}
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
inherit_dim = self._find_inherit_dim(
input_node, v, self.trace_indice.node_list[k]
)
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
if inherit_dim:
input_dim_after_node[k] = inherit_dim
for node in self.trace_indice.node_list[
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
]:
for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
if is_non_compute_node_except_placeholder(node):
continue
count = 0
@@ -159,9 +152,7 @@ class TraceFlow(object):
if arg_node in all_node_info:
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
return False
all_node_info[arg_node]["fix_dim"] = list(
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
)
all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
# else add it to list
else:
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
@@ -170,9 +161,7 @@ class TraceFlow(object):
return True
def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [
self.trace_indice.node_list[end_idx]
] # start from the last node
cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
@@ -183,12 +172,8 @@ class TraceFlow(object):
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
if cur_node_chunk_dim:
cur_node_compute = self.trace_indice._find_compute_trace_from_node(
cur_node
)
cur_node_source = self.trace_indice._find_source_trace_from_node(
cur_node
)
cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
else:
cur_node_compute = cur_node_source = None
@@ -215,15 +200,9 @@ class TraceFlow(object):
return None
if len(arg_list) == 2:
if any(i in cur_node.name for i in ["add", "mul"]):
if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
for arg in arg_list:
if not (
start_idx
<= find_idx_by_name(
arg.name, self.trace_indice.node_list
)
< end_idx
):
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
continue
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
arg_fix_dim = all_node_info[arg]["fix_dim"]
@@ -249,9 +228,7 @@ class TraceFlow(object):
remove_inputs = []
for input_node in inputs:
input_dict = {}
input_node_idx = find_idx_by_name(
input_node.name, self.trace_indice.node_list
)
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
for user in input_node.users.keys():
if is_non_compute_node(user):
continue
@@ -259,9 +236,7 @@ class TraceFlow(object):
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"]
if chunk_dim is not None:
user_source = self.trace_indice._find_source_trace_from_node(
user
)[chunk_dim]
user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
if input_node_idx in user_source:
input_dict[user_idx] = user_source[input_node_idx]
else:
@@ -284,7 +259,7 @@ class TraceFlow(object):
maybe_prepose_nodes.sort(
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
reverse=True,
) # from last node to first node
) # from last node to first node
prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0:
@@ -305,13 +280,8 @@ class TraceFlow(object):
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
if not (
start_idx
<= find_idx_by_name(
cur_prepose_node_arg.name, self.trace_indice.node_list
)
< end_idx
):
if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
end_idx):
continue
# compute op in loop
elif cur_prepose_node_arg in all_node_info:
@@ -335,15 +305,13 @@ class TraceFlow(object):
if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n)
# sort by index
prepose_nodes.sort(
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
)
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
return prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n)
@@ -354,9 +322,7 @@ class TraceFlow(object):
return chunk_info
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.trace_indice.node_list[start_idx : end_idx + 1]
)
inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
# only single output
if len(outputs) > 1:
return None
@@ -367,9 +333,7 @@ class TraceFlow(object):
return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(
inputs, start_idx, end_idx, all_node_info
)
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
@@ -385,9 +349,7 @@ class TraceFlow(object):
}
# move useless nodes ahead of loop
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
all_node_info, start_idx, end_idx
)
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
# find non chunk inputs
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
@@ -400,10 +362,8 @@ class TraceFlow(object):
def _reassgin_reshape_size(self, chunk_info):
chunk_region = chunk_info["region"]
reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[
chunk_info["outputs_dim"]
]
for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]):
reshape_args = node.args[1:]
reshape_log = self.trace_indice.indice_view_list[node]
@@ -413,8 +373,6 @@ class TraceFlow(object):
if reshape_arg_dim in reshape_log["dim_to"]:
continue
if reshape_arg_dim == chunk_dim:
reshape_size[node.name][reshape_arg.name] = (
"min(chunk_size, %d - chunk_idx)" % chunk_shape
)
reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
chunk_info["reshape_size"] = reshape_size
return chunk_info