mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[autochunk] support multi outputs chunk search (#2538)
Support multi-output chunk search. Previously, only single-output chunk search was supported. The new search is more flexible and improves performance by a large margin: for transformers, it reduces memory by 40% compared with the previous search strategy. 1. rewrite the search strategy to support multi-output chunk search 2. fix many, many bugs 3. update tests
This commit is contained in:
@@ -9,6 +9,7 @@ from .select_chunk import SelectChunk
|
||||
from .trace_flow import TraceFlow
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import (
|
||||
NodeMgr,
|
||||
find_chunk_compute_input_and_output_nodes,
|
||||
get_logger,
|
||||
get_node_shape,
|
||||
@@ -49,15 +50,17 @@ class SearchChunk(object):
|
||||
def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
|
||||
self.print_mem = print_mem
|
||||
self.print_progress = print_progress
|
||||
self.trace_indice = TraceIndice(list(gm.graph.nodes))
|
||||
self.estimate_memory = EstimateMemory()
|
||||
self.node_mgr = NodeMgr(gm)
|
||||
self.trace_indice = TraceIndice(self.node_mgr)
|
||||
self.estimate_memory = EstimateMemory(self.node_mgr)
|
||||
self._init_trace()
|
||||
self.trace_flow = TraceFlow(self.trace_indice)
|
||||
self.reorder_graph = ReorderGraph(self.trace_indice)
|
||||
self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
|
||||
self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
|
||||
self.select_chunk = SelectChunk(
|
||||
self.trace_indice,
|
||||
self.estimate_memory,
|
||||
self.reorder_graph,
|
||||
self.node_mgr,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
|
||||
@@ -67,7 +70,7 @@ class SearchChunk(object):
|
||||
reduce the computation complexity of trace_indice
|
||||
"""
|
||||
# find all max ranges
|
||||
active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
|
||||
active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
|
||||
cur_node_idx = len(self._get_free_var_idx())
|
||||
max_chunk_region_list = []
|
||||
while True:
|
||||
@@ -100,7 +103,7 @@ class SearchChunk(object):
|
||||
free_var_idx (List): all indexs of free vars
|
||||
"""
|
||||
free_var_idx = []
|
||||
for idx, n in enumerate(self.trace_indice.node_list):
|
||||
for idx, n in enumerate(self.node_mgr.get_node_list()):
|
||||
if n.op == "placeholder" and get_node_shape(n) is not None:
|
||||
free_var_idx.append(idx)
|
||||
return free_var_idx
|
||||
@@ -164,6 +167,44 @@ class SearchChunk(object):
|
||||
chunk_region_end = region[0] - 1
|
||||
return chunk_region_start, chunk_region_end
|
||||
|
||||
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
|
||||
"""
|
||||
Find chunk info for a region.
|
||||
|
||||
We are given the region start and region end, and need to find out all chunk info for it.
|
||||
We first loop every dim of start node and end node, to see if we can find dim pair,
|
||||
which is linked in a flow and not computed.
|
||||
If found, we then search flow in the whole region to find out all chunk infos.
|
||||
|
||||
Args:
|
||||
input_trace (List): node's input trace in region
|
||||
output_trace (List): node's output trace in region
|
||||
start_idx (int): region start node index
|
||||
end_idx (int): region end node index
|
||||
|
||||
Returns:
|
||||
chunk_infos: possible regions found
|
||||
"""
|
||||
start_traces = input_trace[start_idx]
|
||||
if len(start_traces) > 1: # TODO need to be removed
|
||||
return []
|
||||
end_trace = output_trace[end_idx]
|
||||
end_node = self.node_mgr.get_node_by_idx(end_idx)
|
||||
|
||||
chunk_infos = []
|
||||
for end_dim, _ in enumerate(end_trace["indice"]):
|
||||
for start_node, start_trace in start_traces.items():
|
||||
for start_dim, _ in enumerate(start_trace["indice"]):
|
||||
if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
|
||||
end_idx):
|
||||
continue
|
||||
# flow search
|
||||
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
|
||||
if chunk_info is None:
|
||||
continue
|
||||
chunk_infos.append(chunk_info)
|
||||
return chunk_infos
|
||||
|
||||
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
|
||||
"""
|
||||
Search every possible region within the max chunk region.
|
||||
@@ -178,7 +219,7 @@ class SearchChunk(object):
|
||||
possible_chunk_region = []
|
||||
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
|
||||
input_trace = [] # trace of a node's input nodes
|
||||
for _, n in enumerate(self.trace_indice.node_list):
|
||||
for _, n in enumerate(self.node_mgr.get_node_list()):
|
||||
cur_trace = {}
|
||||
for arg in n.args:
|
||||
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
|
||||
@@ -188,11 +229,11 @@ class SearchChunk(object):
|
||||
for start_idx in range(max_chunk_region[0], peak_node + 1):
|
||||
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
||||
# skip non compute nodes
|
||||
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
|
||||
self.trace_indice.node_list[end_idx]):
|
||||
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
|
||||
self.node_mgr.get_node_by_idx(end_idx)):
|
||||
continue
|
||||
# select free dim
|
||||
chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
|
||||
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
|
||||
if len(chunk_info) > 0:
|
||||
possible_chunk_region.extend(chunk_info)
|
||||
return possible_chunk_region
|
||||
@@ -254,7 +295,7 @@ class SearchChunk(object):
|
||||
init_mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
|
||||
mem_peak = init_mem_peak
|
||||
|
||||
while True:
|
||||
@@ -267,7 +308,7 @@ class SearchChunk(object):
|
||||
mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
|
||||
|
||||
if self.print_progress:
|
||||
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
|
||||
@@ -277,5 +318,7 @@ class SearchChunk(object):
|
||||
break
|
||||
if self.print_mem:
|
||||
self.print_mem = False
|
||||
self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
|
||||
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
|
||||
chunk_infos,
|
||||
print_mem=True)
|
||||
return chunk_infos
|
||||
|
Reference in New Issue
Block a user