[autochunk] support autochunk on evoformer (#2497)

oahzxl
2023-01-19 11:41:00 +08:00
committed by GitHub
parent 304f1ba124
commit ecccc91f21
9 changed files with 200 additions and 188 deletions


@@ -1,8 +1,13 @@
from typing import Dict, List, Tuple
from torch.fx.node import Node
from .trace_indice import TraceIndice
from .utils import (
find_chunk_all_input_nodes,
find_chunk_compute_input_and_output_nodes,
find_idx_by_name,
flat_list,
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
@@ -171,7 +176,7 @@ class TraceFlow(object):
# get cur node info
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
if cur_node_chunk_dim:
if cur_node_chunk_dim is not None:
cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
else:
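
For context on the change above: a chunk dim of 0 is a valid dimension but is falsy in Python, so the old truthiness test treated dim-0 chunking as having no chunk dim at all. A minimal sketch of the difference:

chunk_dim = 0                      # chunking along dim 0 is legal
if chunk_dim:                      # old check: 0 is falsy, so dim 0 is wrongly skipped
    print("treated as chunked (old check)")
if chunk_dim is not None:          # new check: dim 0 is correctly recognized
    print("treated as chunked (new check)")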
@@ -223,15 +228,32 @@ class TraceFlow(object):
cur_node_list = next_node_list
return all_node_info
def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
"""
Get the chunk dim for every entry of each input node, and remove unchunked nodes
Args:
inputs (List[Node]): input nodes
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
inputs (List[Node]): remaining input nodes after filtering
inputs_dim (List): chunk dims for the remaining inputs
"""
inputs_dim = []
remove_inputs = []
for input_node in inputs:
input_dict = {}
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
for user in input_node.users.keys():
# skip non-compute nodes
if is_non_compute_node(user):
continue
# skip untraced nodes, which are mostly non-compute
if user not in all_node_info:
continue
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
if start_idx <= user_idx <= end_idx:
chunk_dim = all_node_info[user]["chunk_dim"]
@@ -245,12 +267,24 @@ class TraceFlow(object):
remove_inputs.append(input_node)
else:
inputs_dim.append(input_dict)
# remove unchunked inputs
for i in remove_inputs:
if i in inputs:
inputs.remove(i)
return inputs, inputs_dim
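
A rough illustration of the filtering above (the dict layouts are simplified stand-ins, not the tracer's real structures): each surviving input keeps a record of the chunk dim expected by every traced user inside the region, and inputs with no such users are dropped:

# Hypothetical data, only to show the shape of the result.
all_node_info = {"matmul_1": {"chunk_dim": 0, "fix_dim": []},
                 "add_1": {"chunk_dim": None, "fix_dim": []}}
users_of = {"x": ["matmul_1", "add_1"], "mask": []}

inputs, inputs_dim, remove_inputs = list(users_of), [], []
for input_node in inputs:
    input_dict = {user: all_node_info[user]["chunk_dim"]
                  for user in users_of[input_node] if user in all_node_info}
    if not input_dict:
        remove_inputs.append(input_node)       # "mask" has no traced users, drop it
    else:
        inputs_dim.append(input_dict)          # {"matmul_1": 0, "add_1": None}
for i in remove_inputs:
    inputs.remove(i)
print(inputs, inputs_dim)                      # ['x'] [{'matmul_1': 0, 'add_1': None}]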
def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
"""
Get all nodes in the chunk region that are not affected by chunking and prepose them
(move them in front of the chunk region)
Args:
all_node_info (Dict): describe all node's chunk dim and fix dim
start_idx (int): chunk start idx
end_idx (int): chunk end idx
Returns:
List[Node]: all nodes to be preposed
"""
# get all possible prepose nodes
maybe_prepose_nodes = []
for node, node_info in all_node_info.items():
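
A conceptual sketch of why preposing helps (illustrative code, not what the tracer or codegen actually emits): a value that does not depend on the chunked dim can be computed once before the chunk loop instead of once per chunk:

import torch

x = torch.randn(8, 16)
weight = torch.randn(16)
out = torch.empty_like(x)
chunk_size = 2

scale = weight.sum()                                   # chunk-independent -> preposed
for chunk_idx in range(0, x.shape[0], chunk_size):
    part = x[chunk_idx:chunk_idx + chunk_size]         # only this slice is live per step
    out[chunk_idx:chunk_idx + chunk_size] = part * scale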
@@ -276,7 +310,7 @@ class TraceFlow(object):
for cur_prepose_node in tmp_cur_prepose_nodes:
if prepose_flag == False:
break
for cur_prepose_node_arg in cur_prepose_node.args:
for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
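
The switch from .args to .all_input_nodes above matters because Node.args can mix Nodes with plain Python values (ints, shape entries) and misses Node inputs passed through kwargs, while torch.fx's Node.all_input_nodes returns exactly the Node dependencies. A quick illustration with a standalone traced module:

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        return x.view(-1, 4) + 1

gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_method" and node.target == "view":
        print(node.args)             # (x, -1, 4): Nodes mixed with plain ints
        print(node.all_input_nodes)  # [x]: only the Node dependencies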
@@ -360,19 +394,28 @@ class TraceFlow(object):
return chunk_info
def _reassgin_reshape_size(self, chunk_info):
"""
Some shape args of reshape/view calls may have changed due to chunking;
reassign those changed shapes
"""
chunk_region = chunk_info["region"]
reshape_size = {}
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
if any(i in node.name for i in ["reshape", "view"]):
reshape_args = node.args[1:]
reshape_log = self.trace_indice.indice_view_list[node]
reshape_args = flat_list(node.args[1:])
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
reshape_size[node.name] = {}
new_shape = ""
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
if reshape_arg_dim in reshape_log["dim_to"]:
continue
if reshape_arg_dim == chunk_dim:
reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
else:
if isinstance(reshape_arg, int):
new_shape += "%s, " % str(reshape_arg)
else:
new_shape += "%s, " % reshape_arg.name
new_shape = new_shape[:-2]
origin_shape = str(reshape_args)[1:-1]
reshape_size[node.name] = [origin_shape, new_shape]
chunk_info["reshape_size"] = reshape_size
return chunk_info
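
A minimal sketch of the substitution built above (hypothetical shape names and sizes): the entry for the chunked dim is replaced by a runtime expression so the generated code also handles the last, possibly smaller chunk:

reshape_args = ["B", "N", "H"]        # symbolic shape args of a view/reshape call
chunk_dim, chunk_shape = 1, 128       # dim being chunked and its full size

new_shape = ", ".join(
    "min(chunk_size, %d - chunk_idx)" % chunk_shape if dim == chunk_dim else arg
    for dim, arg in enumerate(reshape_args))
origin_shape = ", ".join(reshape_args)
print(origin_shape)                   # B, N, H
print(new_shape)                      # B, min(chunk_size, 128 - chunk_idx), H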