[autochunk] support multi outputs chunk search (#2538)

Support multi-output chunk search. Previously we only supported single-output chunk search. The new strategy is more flexible and improves performance by a large margin: for transformer models it reduces memory by 40% compared with the previous search strategy. (A toy sketch of what multi-output chunking means follows the change list below.)

1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests
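
For readers new to autochunk, here is a minimal, self-contained illustration of the idea in plain PyTorch (not ColossalAI code; region, chunked_region and the chunk size are invented for this example): a region with more than one output is executed piece by piece along a chosen dim, and every output is stitched back together, trading peak memory for extra loop iterations.

# Toy illustration only -- the real search picks the region, the dim and the
# chunk size automatically; this just shows what "multi-output chunking" means.
import torch

def region(x):
    a = torch.softmax(x, dim=-1)    # output 1
    b = x * x                       # output 2
    return a, b

def chunked_region(x, chunk_size=32, dim=0):
    outs_a, outs_b = [], []
    for piece in torch.split(x, chunk_size, dim=dim):
        a, b = region(piece)        # run the region on one chunk at a time
        outs_a.append(a)
        outs_b.append(b)
    # stitch both outputs back together along the chunked dim
    return torch.cat(outs_a, dim=dim), torch.cat(outs_b, dim=dim)

x = torch.randn(128, 64)
a_ref, b_ref = region(x)
a_chk, b_chk = chunked_region(x)
assert torch.allclose(a_ref, a_chk) and torch.allclose(b_ref, b_chk)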
Authored by oahzxl on 2023-02-01 13:18:51 +08:00; committed by GitHub
parent f477a14f4a
commit 05671fcb42
14 changed files with 428 additions and 258 deletions


@@ -9,6 +9,7 @@ from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
 from .utils import (
+    NodeMgr,
     find_chunk_compute_input_and_output_nodes,
     get_logger,
     get_node_shape,
@@ -49,15 +50,17 @@ class SearchChunk(object):
     def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
         self.print_mem = print_mem
         self.print_progress = print_progress
-        self.trace_indice = TraceIndice(list(gm.graph.nodes))
-        self.estimate_memory = EstimateMemory()
+        self.node_mgr = NodeMgr(gm)
+        self.trace_indice = TraceIndice(self.node_mgr)
+        self.estimate_memory = EstimateMemory(self.node_mgr)
         self._init_trace()
-        self.trace_flow = TraceFlow(self.trace_indice)
-        self.reorder_graph = ReorderGraph(self.trace_indice)
+        self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
+        self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
         self.select_chunk = SelectChunk(
             self.trace_indice,
             self.estimate_memory,
             self.reorder_graph,
+            self.node_mgr,
             max_memory=max_memory,
         )
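
The refactor threads a single NodeMgr through every component instead of letting each one build its own list(gm.graph.nodes). A minimal sketch of the role implied by this hunk (hypothetical class name; the real implementation lives in the autochunk utils and may carry more lookups):

# Hypothetical sketch of the NodeMgr role suggested by this diff: one owner of
# the ordered FX node list, queried through get_node_list()/get_node_by_idx().
class NodeMgrSketch:

    def __init__(self, gm) -> None:
        self._node_list = list(gm.graph.nodes)

    def get_node_list(self):
        return self._node_list

    def get_node_by_idx(self, idx):
        return self._node_list[idx]

Centralizing the node list this way presumably keeps TraceIndice, TraceFlow, ReorderGraph, EstimateMemory and SelectChunk in agreement about node indices.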
@@ -67,7 +70,7 @@ class SearchChunk(object):
         reduce the computation complexity of trace_indice
         """
         # find all max ranges
-        active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
+        active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list())
         cur_node_idx = len(self._get_free_var_idx())
         max_chunk_region_list = []
         while True:
@@ -100,7 +103,7 @@ class SearchChunk(object):
             free_var_idx (List): all indexs of free vars
         """
         free_var_idx = []
-        for idx, n in enumerate(self.trace_indice.node_list):
+        for idx, n in enumerate(self.node_mgr.get_node_list()):
             if n.op == "placeholder" and get_node_shape(n) is not None:
                 free_var_idx.append(idx)
         return free_var_idx
@@ -164,6 +167,44 @@ class SearchChunk(object):
         chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end
 
+    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
+        """
+        Find chunk info for a region.
+
+        We are given the region start and region end, and need to find out all chunk info for it.
+        We first loop every dim of start node and end node, to see if we can find dim pair,
+        which is linked in a flow and not computed.
+        If found, we then search flow in the whole region to find out all chunk infos.
+
+        Args:
+            input_trace (List): node's input trace in region
+            output_trace (List): node's output trace in region
+            start_idx (int): region start node index
+            end_idx (int): region end node index
+
+        Returns:
+            chunk_infos: possible regions found
+        """
+        start_traces = input_trace[start_idx]
+        if len(start_traces) > 1:    # TODO need to be removed
+            return []
+        end_trace = output_trace[end_idx]
+        end_node = self.node_mgr.get_node_by_idx(end_idx)
+
+        chunk_infos = []
+        for end_dim, _ in enumerate(end_trace["indice"]):
+            for start_node, start_trace in start_traces.items():
+                for start_dim, _ in enumerate(start_trace["indice"]):
+                    if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
+                                                                  end_idx):
+                        continue
+                    # flow search
+                    chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
+                    if chunk_info is None:
+                        continue
+                    chunk_infos.append(chunk_info)
+        return chunk_infos
+
     def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
         """
         Search every possible region within the max chunk region.
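
The new _find_chunk_info helper is where the multi-output support shows up in the search: instead of stopping at the first viable dim pair, it screens every (start_dim, end_dim) combination with check_region_start_end and collects every chunk_info that flow_search accepts. A standalone toy sketch of that enumerate-then-collect pattern (invented names, not the ColossalAI API):

# Run a cheap predicate on every dim pair, run the expensive search only on
# the survivors, and keep *all* results instead of returning the first one.
def collect_candidates(start_dims, end_dims, cheap_check, expensive_search):
    found = []
    for end_dim in end_dims:
        for start_dim in start_dims:
            if not cheap_check(start_dim, end_dim):
                continue
            result = expensive_search(start_dim, end_dim)
            if result is not None:
                found.append(result)
    return found

# e.g. a 3-dim start node and a 2-dim end node where only matching dims link up:
collect_candidates(range(3), range(2),
                   cheap_check=lambda s, e: s == e,
                   expensive_search=lambda s, e: {"start_dim": s, "end_dim": e})
# -> [{'start_dim': 0, 'end_dim': 0}, {'start_dim': 1, 'end_dim': 1}]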
@@ -178,7 +219,7 @@ class SearchChunk(object):
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
         input_trace = []    # trace of a node's input nodes
-        for _, n in enumerate(self.trace_indice.node_list):
+        for _, n in enumerate(self.node_mgr.get_node_list()):
             cur_trace = {}
             for arg in n.args:
                 if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
@@ -188,11 +229,11 @@ class SearchChunk(object):
         for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
-                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
-                        self.trace_indice.node_list[end_idx]):
+                if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
+                        self.node_mgr.get_node_by_idx(end_idx)):
                     continue
                 # select free dim
-                chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
+                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                 if len(chunk_info) > 0:
                     possible_chunk_region.extend(chunk_info)
         return possible_chunk_region
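
As the loop above shows, candidate regions are every (start_idx, end_idx) pair that brackets the peak-memory node inside the maximal chunkable region, with non-compute nodes skipped. A small sketch of just the enumeration (illustrative only, mirroring the two nested range() loops):

# Enumerate every candidate (start, end) region that still contains the peak node.
def enumerate_candidate_regions(max_region, peak_idx):
    region_start, region_end = max_region
    for start_idx in range(region_start, peak_idx + 1):
        for end_idx in range(peak_idx, region_end + 1):
            yield start_idx, end_idx

list(enumerate_candidate_regions((2, 5), peak_idx=4))
# -> [(2, 4), (2, 5), (3, 4), (3, 5), (4, 4), (4, 5)]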
@@ -254,7 +295,7 @@ class SearchChunk(object):
             init_mem_peak,
             _,
             active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
+        ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
         mem_peak = init_mem_peak
 
         while True:
@@ -267,7 +308,7 @@ class SearchChunk(object):
                 mem_peak,
                 _,
                 active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
+            ) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
 
             if self.print_progress:
                 get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
@@ -277,5 +318,7 @@ class SearchChunk(object):
                 break
         if self.print_mem:
             self.print_mem = False
-            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
+            self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
+                                                              chunk_infos,
+                                                              print_mem=True)
         return chunk_infos
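
Taken together, these hunks keep the overall control flow of the search: estimate memory, chunk around the current peak, re-estimate, and repeat until nothing more can be chunked or the budget is met. A compressed, hedged sketch of that loop (the surrounding method is not shown in this diff, and estimate_peak_mem / select_best_region below are placeholders for the real calls):

# Hedged sketch of the greedy search loop implied by the hunks above, not the
# real implementation: each iteration adds at most one chunk region and then
# re-estimates peak inference memory with the regions chosen so far.
def greedy_chunk_search(estimate_peak_mem, select_best_region, max_memory=None):
    chunk_infos = []
    mem_peak = estimate_peak_mem(chunk_infos)
    while True:
        region = select_best_region(chunk_infos, mem_peak)
        if region is None:              # nothing left to chunk
            break
        chunk_infos.append(region)
        mem_peak = estimate_peak_mem(chunk_infos)
        if max_memory is not None and mem_peak < max_memory:
            break                       # fits the budget, stop early
    return chunk_infos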