[autochunk] refactor chunk memory estimation (#2762)

* refact memory code

* dont log free var memory

* add memory align

* update chunk target

* update setting for new memory

* finish test

* update tracer

* update typo

* update test
This commit is contained in:
Xuanlei Zhao
2023-03-08 16:22:30 +08:00
committed by GitHub
parent b51bfec357
commit 2ca9728cbb
12 changed files with 294 additions and 422 deletions

View File

@@ -42,10 +42,11 @@ class SearchChunk(object):
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
self.print_mem = print_mem
self.max_memory = max_memory
self.print_progress = print_progress
self.node_mgr = NodeMgr(gm)
self.node_mgr = NodeMgr(list(gm.graph.nodes))
self.trace_indice = TraceIndice(self.node_mgr)
self.estimate_memory = EstimateMemory(self.node_mgr)
self.estimate_memory = EstimateMemory()
self._init_trace()
self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
@@ -63,45 +64,46 @@ class SearchChunk(object):
reduce the computation complexity of trace_indice
"""
# find all max ranges
active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
cur_node_idx = len(self._get_free_var_idx())
max_chunk_region_list = []
while True:
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
cur_node_idx = max_chunk_region[1] + 1
if cur_node_idx >= len(active_nodes) - 1:
break
max_chunk_region_list.append(max_chunk_region)
# nothing to limit for the first range
max_chunk_region_list = max_chunk_region_list[1:]
max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
active_nodes = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())[2]
# set trace range and do the trace
if self.print_progress:
get_logger().info("AutoChunk start tracing indice")
self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
self.trace_indice.set_active_nodes(active_nodes)
self.trace_indice.trace_indice()
def _find_peak_node(self, mem_peak: List) -> int:
def _find_peak_region(self, mem_peak: List) -> int:
"""
find peak node, along with its neighbour nodes exceeds max mem
"""
max_value = max(mem_peak)
max_idx = mem_peak.index(max_value)
return max_idx
peak_region = [max_idx, max_idx]
if self.max_memory is None:
return peak_region
def _get_free_var_idx(self) -> List:
"""
Get free var index
# to left
count = 0
for i in range(max_idx - 1, -1, -1):
if mem_peak[i] > self.max_memory:
peak_region[0] = i
else:
count += 1
if count >= 3:
break
# to right
count = 0
for i in range(max_idx + 1, len(mem_peak) - 1):
if mem_peak[i] > self.max_memory:
peak_region[1] = i
count = 0
else:
count += 1
if count >= 3:
break
Returns:
free_var_idx (List): all indexs of free vars
"""
free_var_idx = []
for idx, n in enumerate(self.node_mgr.get_node_list()):
if n.op == "placeholder" and get_node_shape(n) is not None:
free_var_idx.append(idx)
return free_var_idx
return peak_region
def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_regions: List = None) -> Tuple:
"""
Search max chunk region according to peak memory node
@@ -119,50 +121,24 @@ class SearchChunk(object):
# check if peak node already in chunkinfo
if chunk_regions is not None:
for i in chunk_regions:
if i["region"][0] < peak_node_idx <= i["region"][1]:
if i["region"][0] < peak_region[0] <= i["region"][1] or \
i["region"][0] < peak_region[1] <= i["region"][1]:
return None
free_vars = self._get_free_var_idx()
free_var_num = len(free_vars)
active_node_num = [len(i) for i in active_node]
min_active_node_num = min(active_node_num[free_var_num:])
threshold = max(free_var_num, min_active_node_num)
# normal search
# from peak_node to free_var
inside_flag = False
chunk_region_start = free_var_num
for i in range(peak_node_idx, -1, -1):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
chunk_region_start = i + 1
break
# from peak_node to len-2
inside_flag = False
chunk_region_end = len(active_node) - 1
for i in range(peak_node_idx, len(active_node)):
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
window_size = 100
# search min for start
min_num = 1e4
for i in range(peak_region[0], max(peak_region[0] - window_size, -1), -1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_start = i
# search min for end
min_num = 1e4
for i in range(peak_region[1], min(peak_region[1] + window_size, len(active_node_num))):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_end = i
break
# if normal search fails, use approximate search
if (chunk_region_end - chunk_region_start) > 250:
window_size = 100
# search min for start
min_num = 1e3
for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_start = i
# search min for end
min_num = 1e3
for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_end = i
# avoid chunk regions overlap
if chunk_regions is not None:
@@ -214,7 +190,7 @@ class SearchChunk(object):
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: Node) -> List:
"""
Search every possible region within the max chunk region.
@@ -235,8 +211,8 @@ class SearchChunk(object):
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
for start_idx in range(max_chunk_region[0], peak_region[0] + 1):
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
self.node_mgr.get_node_by_idx(end_idx)):
@@ -270,13 +246,12 @@ class SearchChunk(object):
Returns:
best_chunk_region (Dict)
"""
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
peak_region = self._find_peak_region(mem_peak)
max_chunk_region = self._search_max_chunk_region(active_node, peak_region, chunk_infos)
if max_chunk_region == None:
return None
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
max_chunk_region, mem_peak)
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_region)
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region