[autochunk] support evoformer tracer (#2485)

support full evoformer tracer, which is a main module of alphafold. previously we just support a simplifed version of it.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code
This commit is contained in:
oahzxl
2023-01-16 19:25:05 +08:00
committed by GitHub
parent 67e1912b59
commit 4953b4ace1
25 changed files with 339 additions and 3215 deletions

View File

@@ -3,10 +3,32 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from torch.fx.node import Node
def unflat_list(inputs):
"""
unflat a list by recursion
"""
res = []
for i in inputs:
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
res.extend(unflat_list(i))
else:
res.append(i)
return res
def find_first_tensor_arg(node):
"""
Find the first input tensor arg for a node
"""
for arg in node.args:
if type(arg) == type(node):
return arg
raise RuntimeError()
def is_non_compute_node(node):
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
i in node.name for i in ["getitem", "getattr"]
):
i in node.name for i in ["getitem", "getattr"]):
return True
return False
@@ -18,17 +40,13 @@ def get_node_shape(node):
def is_non_compute_node_except_placeholder(node):
if any(i in node.op for i in ["get_attr", "output"]) or any(
i in node.name for i in ["getitem", "getattr"]
):
if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]):
return True
return False
def is_non_compute_node_except_placeholder_output(node):
if any(i in node.op for i in ["get_attr"]) or any(
i in node.name for i in ["getitem", "getattr"]
):
if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]):
return True
return False
@@ -74,22 +92,16 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
if (
input_node not in nodes
and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)
):
if (input_node not in nodes and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
if (
output_node not in nodes
and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)
):
if (output_node not in nodes and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)):
output_nodes.append(node)
return input_nodes, output_nodes