[autochunk] support autochunk on evoformer (#2497)

This commit is contained in:
oahzxl
2023-01-19 11:41:00 +08:00
committed by GitHub
parent 304f1ba124
commit ecccc91f21
9 changed files with 200 additions and 188 deletions

View File

@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
from torch.fx.node import Node
from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
class TraceIndice(object):
@@ -28,7 +28,7 @@ class TraceIndice(object):
node_list (List)
"""
def __init__(self, node_list: List) -> None:
def __init__(self, node_list: List[Node]) -> None:
self.node_list = node_list
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
@@ -198,7 +198,7 @@ class TraceIndice(object):
node_idx = find_idx_by_name(node.name, self.node_list)
return self.indice_trace_list[node_idx]["compute"]
def _assign_indice_as_input(self, node, node_idx, input_node=None):
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
"""
Assign node's trace as its input node.
@@ -216,7 +216,7 @@ class TraceIndice(object):
self._inherit_all_computation(input_node, node)
def _assign_all_indice(self, node, node_idx):
def _assign_all_indice(self, node: Node, node_idx: int):
"""
Add new indice for all node's dims.
@@ -232,7 +232,7 @@ class TraceIndice(object):
new_trace.append(self._add_indice())
self.indice_trace_list[node_idx]["indice"] = new_trace
def _assign_transpose_indice(self, node, node_idx):
def _assign_transpose_indice(self, node: Node, node_idx: int):
"""
Assign indice for transpose op.
1. swap input's dim according to transpose args
@@ -249,7 +249,7 @@ class TraceIndice(object):
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
def _assign_permute_indice(self, node, node_idx):
def _assign_permute_indice(self, node: Node, node_idx: int):
"""
Assign indice for permute op.
1. swap input's dim according to permute args
@@ -259,14 +259,14 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
permute_dim = unflat_list(node.args[1:])
permute_dim = flat_list(node.args[1:])
input_node = node.args[0]
self._assign_indice_as_input(node, node_idx, input_node)
for idx, d in enumerate(permute_dim):
self._inherit_indice(input_node, d, node, idx)
def _assign_linear_indice(self, node, node_idx):
def _assign_linear_indice(self, node: Node, node_idx: int):
"""
Assign indice for linear op.
1. copy trace from input node and change last indice according to weight
@@ -287,7 +287,7 @@ class TraceIndice(object):
self._mark_computation(node, node_idx, [-1])
def _assign_matmul_indice(self, node, node_idx):
def _assign_matmul_indice(self, node: Node, node_idx: int):
"""
Assign indice for matmul op.
1. copy trace from matmul_left and change last indice according to matmul_right. (assert they have same length)
@@ -393,7 +393,7 @@ class TraceIndice(object):
self._assign_indice_as_input(node, idx)
self._mark_computation(node, idx, [node.kwargs["dim"]])
def _assign_unsqueeze_indice(self, node, node_idx):
def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
"""
Assign indice for unsqueeze op.
1. assign new indice for unsqueeze dim
@@ -404,9 +404,13 @@ class TraceIndice(object):
"""
self._del_dim(node_idx, -1)
self._assign_indice_as_input(node, node_idx)
self._add_dim(node_idx, node.args[1])
dim_idx = node.args[1]
# convert a negative dim to its non-negative index, e.g. unsqueeze(-1) -> dim = len(shape) - 1
if dim_idx < 0:
dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
self._add_dim(node_idx, dim_idx)
def _assign_dropout_indice(self, node, node_idx):
def _assign_dropout_indice(self, node: Node, node_idx: int):
"""
Assign indice for dropout op.
1. assign indice as its input node
@@ -417,7 +421,7 @@ class TraceIndice(object):
"""
self._assign_indice_as_input(node, node_idx)
def _assign_ones_like_indice(self, node, node_idx):
def _assign_ones_like_indice(self, node: Node, node_idx: int):
"""
Assign indice for oneslike op.
1. assign new indice for all dim
@@ -428,7 +432,47 @@ class TraceIndice(object):
"""
self._assign_all_indice(node, node_idx)
def _assign_view_reshape_indice(self, node, node_idx):
def _assign_getitem_indice(self, node: Node, node_idx: int):
    """
    Assign indice for getitem.
    getitem can insert new dims (like unsqueeze) when it is indexed with None,
    e.g. x[..., None]. Only that case is traced here: if the index contains
    neither None nor Ellipsis, the node's trace is left untouched.

    Args:
        node (node)
        node_idx (int)

    Raises:
        NotImplementedError: if an index entry is anything other than
            Ellipsis, slice(None, None, None) or None.
    """
    node_args = flat_list(node.args[1:])
    # index entries are matched by their string form, e.g. "None", "Ellipsis"
    node_arg_strs = [str(node_arg) for node_arg in node_args]
    if "None" not in node_arg_strs and "Ellipsis" not in node_arg_strs:
        return

    # node args should be like [Ellipsis, slice(start, stop, step), None]
    node_shape = get_node_shape(node)

    # drop one dim per inserted None before copying the input's trace
    # NOTE(review): presumably this makes the trace length match the input's
    # rank -- confirm against _del_dim / _assign_indice_as_input
    for _ in range(node_arg_strs.count("None")):
        self._del_dim(node_idx, 0)
    self._assign_indice_as_input(node, node_idx)

    # position in the node's dims that the current index entry refers to
    new_idx_count = 0
    for node_arg_str in node_arg_strs:
        # Ellipsis means [..., ]: it stands for every dim not named explicitly
        if node_arg_str == "Ellipsis":
            new_idx_count += len(node_shape) - len(node_args) + 1
        # slice(None, None, None) means all indexes, doesn't support other slice
        elif node_arg_str == "slice(None, None, None)":
            new_idx_count += 1
        # None means a new dim
        elif node_arg_str == "None":
            self._add_dim(node_idx, new_idx_count)
            new_idx_count += 1
        else:
            raise NotImplementedError(
                f"getitem index {node_arg_str} is not supported in indice tracing")
def _assign_view_reshape_indice(self, node: Node, node_idx: int):
"""
Assign indice for view and reshape op.
1. get origin shape and target shape by meta info.
@@ -447,7 +491,7 @@ class TraceIndice(object):
origin_node = node.args[0]
origin_shape = origin_node.meta["tensor_meta"].shape
target_shape = []
unflated_args = unflat_list(node.args)
unflated_args = flat_list(node.args)
for i in range(1, len(unflated_args)):
if isinstance(unflated_args[i], int):
target_shape.append(unflated_args[i])
@@ -544,6 +588,8 @@ class TraceIndice(object):
self._assign_einsum_indice(node, idx)
elif "layer_norm" in node.name:
self._assign_layernorm_indice(node, idx)
elif "getitem" in node.name:
self._assign_getitem_indice(node, idx)
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
continue
else: