[autochunk] support multi outputs chunk search (#2538)

Support multi-output chunk search. Previously, only single-output chunk search was supported. The new strategy is more flexible and improves performance by a large margin. For a transformer, it reduces memory usage by 40% compared with the previous search strategy.

1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests
This commit is contained in:
oahzxl
2023-02-01 13:18:51 +08:00
committed by GitHub
parent f477a14f4a
commit 05671fcb42
14 changed files with 428 additions and 258 deletions

View File

@@ -1,5 +1,5 @@
from .trace_indice import TraceIndice
from .utils import find_idx_by_name
from .utils import NodeMgr
class ReorderGraph(object):
@@ -7,31 +7,27 @@ class ReorderGraph(object):
Reorder node list and indice trace list
"""
def __init__(self, trace_indice: TraceIndice) -> None:
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.all_reorder_map = {
i: i for i in range(len(self.trace_indice.indice_trace_list))
}
self.node_mgr = node_mgr
self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
def _get_reorder_map(self, chunk_info):
reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}
reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
chunk_region_start = chunk_info["region"][0]
chunk_region_end = chunk_info["region"][1]
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
chunk_prepose_nodes_idx = [
find_idx_by_name(i.name, self.trace_indice.node_list)
for i in chunk_prepose_nodes
]
chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(i) for i in chunk_prepose_nodes]
# put prepose nodes ahead
for idx, n in enumerate(chunk_prepose_nodes):
n_idx = chunk_prepose_nodes_idx[idx]
reorder_map[n_idx] = chunk_region_start + idx
# put other nodes after prepose nodes
for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1):
if n in chunk_prepose_nodes:
continue
n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
n_idx = self.node_mgr.find_node_idx(n)
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
reorder_map[n_idx] = n_idx + pos
@@ -44,7 +40,7 @@ class ReorderGraph(object):
chunk_info["region"][1],
)
new_inputs_dim = []
for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
for _, input_dim in enumerate(chunk_info["inputs_dim"]):
new_input_dim = {}
for k, v in input_dim.items():
new_input_dim[reorder_map[k]] = v
@@ -57,16 +53,14 @@ class ReorderGraph(object):
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
def _reorder_self_node_list(self, reorder_map):
new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
new_node_list = [None for _ in range(len(self.node_mgr.get_node_list()))]
for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
self.trace_indice.node_list = new_node_list
new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx)
self.node_mgr.update_node_list(new_node_list)
def _reorder_idx_trace(self, reorder_map):
# reorder list
new_idx_trace_list = [
None for _ in range(len(self.trace_indice.indice_trace_list))
]
new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
for old_idx, new_idx in reorder_map.items():
new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
self.trace_indice.indice_trace_list = new_idx_trace_list