[autochunk] support autochunk on evoformer (#2497)

This commit is contained in:
oahzxl
2023-01-19 11:41:00 +08:00
committed by GitHub
parent 304f1ba124
commit ecccc91f21
9 changed files with 200 additions and 188 deletions

View File

@@ -3,14 +3,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from torch.fx.node import Node
def unflat_list(inputs):
def flat_list(inputs):
"""
unflat a list by recursion
flat a list by recursion
"""
res = []
for i in inputs:
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
res.extend(unflat_list(i))
res.extend(flat_list(i))
else:
res.append(i)
return res
@@ -27,8 +27,13 @@ def find_first_tensor_arg(node):
def is_non_compute_node(node):
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
i in node.name for i in ["getitem", "getattr"]):
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
return True
if "getitem" in node.name:
node_args = flat_list(node.args[1:])
for node_arg in node_args:
if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
return False
return True
return False
@@ -40,15 +45,15 @@ def get_node_shape(node):
def is_non_compute_node_except_placeholder(node):
if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]):
return True
return False
if "placeholder" in node.op:
return False
return is_non_compute_node(node)
def is_non_compute_node_except_placeholder_output(node):
if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]):
return True
return False
if "output" in node.op:
return False
return is_non_compute_node_except_placeholder(node)
def find_idx_by_name(name, nodes_list):