[autochunk] support multi outputs chunk search (#2538)

Support multi outputs chunk search. Previously we only support single output chunk search. It is more flexible and improve performance by a large margin. For transformer, we reduce memory by 40% than previous search strategy.

1. rewrite search strategy to support multi outputs chunk search
2. fix many, many bugs
3. update tests
This commit is contained in:
oahzxl
2023-02-01 13:18:51 +08:00
committed by GitHub
parent f477a14f4a
commit 05671fcb42
14 changed files with 428 additions and 258 deletions

View File

@@ -1,7 +1,7 @@
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_indice import TraceIndice
from .utils import is_non_compute_node
from .utils import NodeMgr, is_non_compute_node
class SelectChunk(object):
@@ -11,11 +11,13 @@ class SelectChunk(object):
trace_indice: TraceIndice,
estimate_memory: EstimateMemory,
reorder_graph: ReorderGraph,
node_mgr: NodeMgr,
max_memory=None,
):
self.trace_indice = trace_indice
self.estimate_memory = estimate_memory
self.reorder_graph = reorder_graph
self.node_mgr = node_mgr
if max_memory is not None:
self.stratge = "fit_memory"
self.max_memory = max_memory # MB
@@ -68,7 +70,7 @@ class SelectChunk(object):
regions_dict = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
@@ -134,7 +136,7 @@ class SelectChunk(object):
def _get_compute_node_num(self, start, end):
count = 0
for i in self.trace_indice.node_list[start:end + 1]:
for i in self.node_mgr.get_node_slice_by_idx(start, end + 1):
if not is_non_compute_node(i):
count += 1
return count
@@ -161,7 +163,7 @@ class SelectChunk(object):
regions_dict_list = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]