mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[autochunk] support parsing blocks (#2506)
This commit is contained in:
@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
|
||||
from .select_chunk import SelectChunk
|
||||
from .trace_flow import TraceFlow
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
|
||||
from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
|
||||
|
||||
|
||||
class SearchChunk(object):
|
||||
@@ -40,14 +40,14 @@ class SearchChunk(object):
|
||||
print_mem (bool): print estimated memory
|
||||
"""
|
||||
|
||||
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
|
||||
self.gm = gm
|
||||
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
|
||||
self.print_mem = print_mem
|
||||
self.print_progress = print_progress
|
||||
self.trace_indice = TraceIndice(list(gm.graph.nodes))
|
||||
self.trace_indice.trace_indice()
|
||||
self.estimate_memory = EstimateMemory()
|
||||
self._init_trace()
|
||||
self.trace_flow = TraceFlow(self.trace_indice)
|
||||
self.reorder_graph = ReorderGraph(self.trace_indice)
|
||||
self.estimate_memory = EstimateMemory()
|
||||
self.select_chunk = SelectChunk(
|
||||
self.trace_indice,
|
||||
self.estimate_memory,
|
||||
@@ -55,7 +55,33 @@ class SearchChunk(object):
|
||||
max_memory=max_memory,
|
||||
)
|
||||
|
||||
def _find_peak_node(self, mem_peak):
|
||||
def _init_trace(self) -> None:
    """Find the max trace range for every node to cut tracing cost.

    Runs a first coarse memory estimation to locate candidate chunk
    regions, hands those ranges to ``trace_indice`` so it only traces
    inside them, then performs the (now cheaper) indice trace.

    Returns:
        None. Side effect: ``self.trace_indice`` is configured with
        trace ranges and fully traced.
    """
    # Coarse pass: per-node active-node status drives the region search.
    active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
    # Start scanning after the free variables (graph inputs).
    cur_node_idx = len(self._get_free_var_idx())
    max_chunk_region_list = []
    while True:
        max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
        # Region end becomes the next scan start.
        cur_node_idx = max_chunk_region[1]
        if cur_node_idx == len(active_nodes) - 1:
            break
        max_chunk_region_list.append(max_chunk_region)

    # Nothing to limit for the first range: drop it and open the new
    # first range from node 0.
    max_chunk_region_list = max_chunk_region_list[1:]
    # Guard: with a single discovered region the list is now empty and
    # indexing [0] would raise IndexError.
    if max_chunk_region_list:
        max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])

    # Set trace ranges and run the real (restricted) trace.
    if self.print_progress:
        get_logger().info("AutoChunk start tracing indice")
    self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
    self.trace_indice.trace_indice()
|
||||
|
||||
def _find_peak_node(self, mem_peak: List) -> int:
|
||||
max_value = max(mem_peak)
|
||||
max_idx = mem_peak.index(max_value)
|
||||
return max_idx
|
||||
@@ -73,7 +99,7 @@ class SearchChunk(object):
|
||||
free_var_idx.append(idx)
|
||||
return free_var_idx
|
||||
|
||||
def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
|
||||
def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
|
||||
"""
|
||||
Search max chunk region according to peak memory node
|
||||
|
||||
@@ -81,7 +107,7 @@ class SearchChunk(object):
|
||||
|
||||
Args:
|
||||
active_node (List): active node status for every node
|
||||
peak_node (Node): peak memory node
|
||||
peak_node_idx (int): peak memory node idx
|
||||
chunk_regions (List): chunk region infos
|
||||
|
||||
Returns:
|
||||
@@ -97,7 +123,7 @@ class SearchChunk(object):
|
||||
# from peak_node to free_var
|
||||
inside_flag = False
|
||||
chunk_region_start = free_var_num
|
||||
for i in range(peak_node, -1, -1):
|
||||
for i in range(peak_node_idx, -1, -1):
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
@@ -107,21 +133,23 @@ class SearchChunk(object):
|
||||
# from peak_node to len-2
|
||||
inside_flag = False
|
||||
chunk_region_end = len(active_node) - 1
|
||||
for i in range(peak_node, len(active_node)):
|
||||
for i in range(peak_node_idx, len(active_node)):
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
chunk_region_end = i
|
||||
break
|
||||
|
||||
for i in chunk_regions:
|
||||
region = i["region"]
|
||||
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
|
||||
return None
|
||||
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
|
||||
chunk_region_start = region[1] + 1
|
||||
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
|
||||
chunk_region_end = region[0] - 1
|
||||
# avoid chunk regions overlap
|
||||
if chunk_regions is not None:
|
||||
for i in chunk_regions:
|
||||
region = i["region"]
|
||||
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
|
||||
return None
|
||||
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
|
||||
chunk_region_start = region[1] + 1
|
||||
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
|
||||
chunk_region_end = region[0] - 1
|
||||
return chunk_region_start, chunk_region_end
|
||||
|
||||
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
|
||||
@@ -154,6 +182,9 @@ class SearchChunk(object):
|
||||
# dim size cannot be 1
|
||||
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
|
||||
continue
|
||||
# must have users
|
||||
if len(end_node.users) == 0:
|
||||
continue
|
||||
# check index source align
|
||||
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
|
||||
continue
|
||||
@@ -253,6 +284,9 @@ class SearchChunk(object):
|
||||
Returns:
|
||||
chunk_infos (Dict)
|
||||
"""
|
||||
if self.print_progress:
|
||||
get_logger().info("AutoChunk start searching chunk regions")
|
||||
|
||||
chunk_infos = []
|
||||
(
|
||||
init_mem_peak,
|
||||
@@ -272,6 +306,11 @@ class SearchChunk(object):
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
|
||||
|
||||
if self.print_progress:
|
||||
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
|
||||
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
|
||||
|
||||
if self._stop_search(init_mem_peak, mem_peak):
|
||||
break
|
||||
if self.print_mem:
|
||||
|
Reference in New Issue
Block a user