[autochunk] support autochunk on evoformer (#2497)

This commit is contained in:
oahzxl
2023-01-19 11:41:00 +08:00
committed by GitHub
parent 304f1ba124
commit ecccc91f21
9 changed files with 200 additions and 188 deletions

View File

@@ -8,11 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
@@ -73,13 +69,11 @@ class SearchChunk(object):
"""
free_var_idx = []
for idx, n in enumerate(self.trace_indice.node_list):
if n.op == "placeholder":
if n.op == "placeholder" and get_node_shape(n) is not None:
free_var_idx.append(idx)
return free_var_idx
def _search_max_chunk_region(
self, active_node: List, peak_node: Node, chunk_regions: List
) -> Tuple:
def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
"""
Search max chunk region according to peak memory node
@@ -124,15 +118,9 @@ class SearchChunk(object):
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (
region[0] <= chunk_region_start <= region[1]
and chunk_region_end > region[1]
):
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
chunk_region_start = region[1] + 1
elif (
region[0] <= chunk_region_end <= region[1]
and chunk_region_start < region[0]
):
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
@@ -164,25 +152,16 @@ class SearchChunk(object):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
# dim size cannot be 1
if (
get_node_shape(end_node)[end_dim] == 1
or get_node_shape(start_node)[start_dim] == 1
):
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
continue
# check index source align
if not self.trace_flow.check_index_source(
start_dim, start_node, start_idx, end_dim, end_node
):
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
continue
# check index copmute
if not self.trace_flow.check_index_compute(
start_idx, end_dim, end_node, end_idx
):
if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
continue
# flow search
chunk_info = self.trace_flow.flow_search(
start_idx, start_dim, end_idx, end_dim
)
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
if chunk_info is None:
continue
# check index copmute
@@ -191,9 +170,7 @@ class SearchChunk(object):
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(
self, max_chunk_region: Tuple, peak_node: Node
) -> List:
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
"""
Search every possible region within the max chunk region.
@@ -206,28 +183,23 @@ class SearchChunk(object):
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.trace_indice.node_list):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
arg
):
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(
self.trace_indice.node_list[start_idx]
) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
self.trace_indice.node_list[end_idx]):
continue
# select free dim
chunk_info = self._find_chunk_info(
input_trace, output_trace, start_idx, end_idx
)
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
if len(chunk_info) > 0:
possible_chunk_region.extend(chunk_info)
return possible_chunk_region
@@ -256,17 +228,12 @@ class SearchChunk(object):
best_chunk_region (Dict)
"""
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(
active_node, peak_node, chunk_infos
)
max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
if max_chunk_region == None:
return None
possible_chunk_regions = self._search_possible_chunk_regions(
max_chunk_region, peak_node
)
best_chunk_region = self.select_chunk._select_best_chunk_region(
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
)
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
max_chunk_region, mem_peak)
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region
@@ -291,9 +258,7 @@ class SearchChunk(object):
init_mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_indice.node_list
)
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
mem_peak = init_mem_peak
while True:
@@ -306,14 +271,10 @@ class SearchChunk(object):
mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_indice.node_list, chunk_infos
)
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(
self.trace_indice.node_list, chunk_infos, print_mem=True
)
self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
return chunk_infos