mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[autochunk] refactor chunk memory estimation (#2762)
* refactor memory code
* don't log free var memory
* add memory align
* update chunk target
* update setting for new memory
* finish test
* update tracer
* update typo
* update test
This commit is contained in:
@@ -42,10 +42,11 @@ class SearchChunk(object):
|
||||
|
||||
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
|
||||
self.print_mem = print_mem
|
||||
self.max_memory = max_memory
|
||||
self.print_progress = print_progress
|
||||
self.node_mgr = NodeMgr(gm)
|
||||
self.node_mgr = NodeMgr(list(gm.graph.nodes))
|
||||
self.trace_indice = TraceIndice(self.node_mgr)
|
||||
self.estimate_memory = EstimateMemory(self.node_mgr)
|
||||
self.estimate_memory = EstimateMemory()
|
||||
self._init_trace()
|
||||
self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
|
||||
self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
|
||||
@@ -63,45 +64,46 @@ class SearchChunk(object):
|
||||
reduce the computation complexity of trace_indice
|
||||
"""
|
||||
# find all max ranges
|
||||
active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
|
||||
cur_node_idx = len(self._get_free_var_idx())
|
||||
max_chunk_region_list = []
|
||||
while True:
|
||||
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
|
||||
cur_node_idx = max_chunk_region[1] + 1
|
||||
if cur_node_idx >= len(active_nodes) - 1:
|
||||
break
|
||||
max_chunk_region_list.append(max_chunk_region)
|
||||
|
||||
# nothing to limit for the first range
|
||||
max_chunk_region_list = max_chunk_region_list[1:]
|
||||
max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
|
||||
|
||||
active_nodes = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())[2]
|
||||
# set trace range and do the trace
|
||||
if self.print_progress:
|
||||
get_logger().info("AutoChunk start tracing indice")
|
||||
self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
|
||||
self.trace_indice.set_active_nodes(active_nodes)
|
||||
self.trace_indice.trace_indice()
|
||||
|
||||
def _find_peak_node(self, mem_peak: List) -> int:
|
||||
def _find_peak_region(self, mem_peak: List) -> int:
|
||||
"""
|
||||
find peak node, along with its neighbour nodes exceeds max mem
|
||||
"""
|
||||
max_value = max(mem_peak)
|
||||
max_idx = mem_peak.index(max_value)
|
||||
return max_idx
|
||||
peak_region = [max_idx, max_idx]
|
||||
if self.max_memory is None:
|
||||
return peak_region
|
||||
|
||||
def _get_free_var_idx(self) -> List:
|
||||
"""
|
||||
Get free var index
|
||||
# to left
|
||||
count = 0
|
||||
for i in range(max_idx - 1, -1, -1):
|
||||
if mem_peak[i] > self.max_memory:
|
||||
peak_region[0] = i
|
||||
else:
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
# to right
|
||||
count = 0
|
||||
for i in range(max_idx + 1, len(mem_peak) - 1):
|
||||
if mem_peak[i] > self.max_memory:
|
||||
peak_region[1] = i
|
||||
count = 0
|
||||
else:
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
|
||||
Returns:
|
||||
free_var_idx (List): all indexs of free vars
|
||||
"""
|
||||
free_var_idx = []
|
||||
for idx, n in enumerate(self.node_mgr.get_node_list()):
|
||||
if n.op == "placeholder" and get_node_shape(n) is not None:
|
||||
free_var_idx.append(idx)
|
||||
return free_var_idx
|
||||
return peak_region
|
||||
|
||||
def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
|
||||
def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_regions: List = None) -> Tuple:
|
||||
"""
|
||||
Search max chunk region according to peak memory node
|
||||
|
||||
@@ -119,50 +121,24 @@ class SearchChunk(object):
|
||||
# check if peak node already in chunkinfo
|
||||
if chunk_regions is not None:
|
||||
for i in chunk_regions:
|
||||
if i["region"][0] < peak_node_idx <= i["region"][1]:
|
||||
if i["region"][0] < peak_region[0] <= i["region"][1] or \
|
||||
i["region"][0] < peak_region[1] <= i["region"][1]:
|
||||
return None
|
||||
|
||||
free_vars = self._get_free_var_idx()
|
||||
free_var_num = len(free_vars)
|
||||
active_node_num = [len(i) for i in active_node]
|
||||
min_active_node_num = min(active_node_num[free_var_num:])
|
||||
threshold = max(free_var_num, min_active_node_num)
|
||||
|
||||
# normal search
|
||||
# from peak_node to free_var
|
||||
inside_flag = False
|
||||
chunk_region_start = free_var_num
|
||||
for i in range(peak_node_idx, -1, -1):
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
chunk_region_start = i + 1
|
||||
break
|
||||
# from peak_node to len-2
|
||||
inside_flag = False
|
||||
chunk_region_end = len(active_node) - 1
|
||||
for i in range(peak_node_idx, len(active_node)):
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
window_size = 100
|
||||
# search min for start
|
||||
min_num = 1e4
|
||||
for i in range(peak_region[0], max(peak_region[0] - window_size, -1), -1):
|
||||
if active_node_num[i] < min_num:
|
||||
min_num = active_node_num[i]
|
||||
chunk_region_start = i
|
||||
# search min for end
|
||||
min_num = 1e4
|
||||
for i in range(peak_region[1], min(peak_region[1] + window_size, len(active_node_num))):
|
||||
if active_node_num[i] < min_num:
|
||||
min_num = active_node_num[i]
|
||||
chunk_region_end = i
|
||||
break
|
||||
|
||||
# if normal search fails, use approximate search
|
||||
if (chunk_region_end - chunk_region_start) > 250:
|
||||
window_size = 100
|
||||
# search min for start
|
||||
min_num = 1e3
|
||||
for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
|
||||
if active_node_num[i] < min_num:
|
||||
min_num = active_node_num[i]
|
||||
chunk_region_start = i
|
||||
# search min for end
|
||||
min_num = 1e3
|
||||
for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
|
||||
if active_node_num[i] < min_num:
|
||||
min_num = active_node_num[i]
|
||||
chunk_region_end = i
|
||||
|
||||
# avoid chunk regions overlap
|
||||
if chunk_regions is not None:
|
||||
@@ -214,7 +190,7 @@ class SearchChunk(object):
|
||||
chunk_infos.append(chunk_info)
|
||||
return chunk_infos
|
||||
|
||||
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
|
||||
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: Node) -> List:
|
||||
"""
|
||||
Search every possible region within the max chunk region.
|
||||
|
||||
@@ -235,8 +211,8 @@ class SearchChunk(object):
|
||||
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
|
||||
input_trace.append(cur_trace)
|
||||
|
||||
for start_idx in range(max_chunk_region[0], peak_node + 1):
|
||||
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
||||
for start_idx in range(max_chunk_region[0], peak_region[0] + 1):
|
||||
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
|
||||
# skip non compute nodes
|
||||
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
|
||||
self.node_mgr.get_node_by_idx(end_idx)):
|
||||
@@ -270,13 +246,12 @@ class SearchChunk(object):
|
||||
Returns:
|
||||
best_chunk_region (Dict)
|
||||
"""
|
||||
peak_node = self._find_peak_node(mem_peak)
|
||||
max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
|
||||
peak_region = self._find_peak_region(mem_peak)
|
||||
max_chunk_region = self._search_max_chunk_region(active_node, peak_region, chunk_infos)
|
||||
if max_chunk_region == None:
|
||||
return None
|
||||
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
|
||||
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
|
||||
max_chunk_region, mem_peak)
|
||||
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_region)
|
||||
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
|
||||
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
|
||||
return best_chunk_region
|
||||
|
||||
|
Reference in New Issue
Block a user