[autochunk] support evoformer tracer (#2485)

Support the full Evoformer tracer; Evoformer is a main module of AlphaFold. Previously we only supported a simplified version of it.
1. support some of evoformer's ops in fx
2. support evoformer test
3. add repos for test code
This commit is contained in:
oahzxl
2023-01-16 19:25:05 +08:00
committed by GitHub
parent 67e1912b59
commit 4953b4ace1
25 changed files with 339 additions and 3215 deletions

View File

@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
from torch.fx.node import Node
from .utils import find_idx_by_name, get_node_shape
from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
class TraceIndice(object):
@@ -79,9 +79,7 @@ class TraceIndice(object):
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
node_to_trace["compute"][node_to_dim] = copy.deepcopy(
node_from_trace["compute"][node_from_dim]
)
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
def _inherit_all_computation(self, node_from, node_to):
@@ -209,7 +207,7 @@ class TraceIndice(object):
node_idx (int)
"""
if input_node == None:
input_node = node.args[0]
input_node = find_first_tensor_arg(node)
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
@@ -227,6 +225,8 @@ class TraceIndice(object):
node_idx (int)
"""
shape = node.meta["tensor_meta"].shape
if shape is None:
return
new_trace = []
for _ in shape:
new_trace.append(self._add_indice())
@@ -259,7 +259,7 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
permute_dim = node.args[1:]
permute_dim = unflat_list(node.args[1:])
input_node = node.args[0]
self._assign_indice_as_input(node, node_idx, input_node)
@@ -359,6 +359,15 @@ class TraceIndice(object):
left, right = patterns.split("->")
left = left.split(",")
if '...' in right:
replace_list = "!@#$%^&*"
target_len = len(get_node_shape(node))
add_len = target_len - len(right) + 3
replace_str = replace_list[:add_len]
right = right.replace("...", replace_str)
for ll in range(len(left)):
left[ll] = left[ll].replace("...", replace_str)
all_index = []
for i in left:
for c in i:
@@ -369,9 +378,7 @@ class TraceIndice(object):
for left_idx, left_str in enumerate(left):
if right_indice in left_str:
source_idx = left_str.index(right_indice)
self._inherit_indice(
input_nodes[left_idx], source_idx, node, right_idx
)
self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx)
def _assign_softmax_indice(self, node, idx):
"""
@@ -440,11 +447,12 @@ class TraceIndice(object):
origin_node = node.args[0]
origin_shape = origin_node.meta["tensor_meta"].shape
target_shape = []
for i in range(1, len(node.args)):
if isinstance(node.args[i], int):
target_shape.append(node.args[i])
unflated_args = unflat_list(node.args)
for i in range(1, len(unflated_args)):
if isinstance(unflated_args[i], int):
target_shape.append(unflated_args[i])
else:
target_shape.append(node.args[i].meta["fwd_out"][0])
target_shape.append(unflated_args[i].meta["fwd_out"][0])
# compute the value of -1
if -1 in target_shape:
@@ -472,13 +480,7 @@ class TraceIndice(object):
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
self._del_dim(node_idx, -1)
else:
raise NotImplementedError(
"shape"
+ str(origin_shape)
+ "and"
+ str(target_shape)
+ "view not implemented"
)
raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
# get new indice
origin_trace = self._find_indice_trace_from_node(origin_node)
@@ -521,6 +523,8 @@ class TraceIndice(object):
self._assign_unsqueeze_indice(node, idx)
elif any(i in node.name for i in ["to", "contiguous"]):
self._assgin_no_change_indice(node, idx)
elif "new_ones" in node.name:
self._assign_ones_like_indice(node, idx)
else:
raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == "call_function":
@@ -530,7 +534,7 @@ class TraceIndice(object):
self._assign_matmul_indice(node, idx)
elif "softmax" in node.name:
self._assign_softmax_indice(node, idx)
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]):
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
self._assign_elementwise_indice(node, idx)
elif "ones_like" in node.name:
self._assign_ones_like_indice(node, idx)
@@ -538,21 +542,21 @@ class TraceIndice(object):
self._assign_dropout_indice(node, idx)
elif "einsum" in node.name:
self._assign_einsum_indice(node, idx)
elif "getattr" in node.name:
continue # get attr like shape
elif "getitem" in node.name:
continue # get item in list
elif "layer_norm" in node.name:
self._assign_layernorm_indice(node, idx)
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
continue
else:
raise NotImplementedError(
node.name, "function not implemented yet!"
)
raise NotImplementedError(node.name, "function not implemented yet!")
elif node.op == "call_module":
if any(n in node.name for n in ["layernorm", "norm"]):
self._assign_layernorm_indice(node, idx)
elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
self._assign_elementwise_indice(node, idx)
else:
raise NotImplementedError(node.name, "module not implemented yet!")
elif node.op == "get_attr":
self._assign_all_indice(node, idx) # get param
self._assign_all_indice(node, idx) # get param
elif node.op == "output":
continue
else: