mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[autochunk] support autochunk on evoformer (#2497)
This commit is contained in:
@@ -8,11 +8,7 @@ from .reorder_graph import ReorderGraph
|
||||
from .select_chunk import SelectChunk
|
||||
from .trace_flow import TraceFlow
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import (
|
||||
get_node_shape,
|
||||
is_non_compute_node,
|
||||
is_non_compute_node_except_placeholder,
|
||||
)
|
||||
from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
|
||||
|
||||
|
||||
class SearchChunk(object):
|
||||
@@ -73,13 +69,11 @@ class SearchChunk(object):
|
||||
"""
|
||||
free_var_idx = []
|
||||
for idx, n in enumerate(self.trace_indice.node_list):
|
||||
if n.op == "placeholder":
|
||||
if n.op == "placeholder" and get_node_shape(n) is not None:
|
||||
free_var_idx.append(idx)
|
||||
return free_var_idx
|
||||
|
||||
def _search_max_chunk_region(
|
||||
self, active_node: List, peak_node: Node, chunk_regions: List
|
||||
) -> Tuple:
|
||||
def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
|
||||
"""
|
||||
Search max chunk region according to peak memory node
|
||||
|
||||
@@ -124,15 +118,9 @@ class SearchChunk(object):
|
||||
region = i["region"]
|
||||
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
|
||||
return None
|
||||
elif (
|
||||
region[0] <= chunk_region_start <= region[1]
|
||||
and chunk_region_end > region[1]
|
||||
):
|
||||
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
|
||||
chunk_region_start = region[1] + 1
|
||||
elif (
|
||||
region[0] <= chunk_region_end <= region[1]
|
||||
and chunk_region_start < region[0]
|
||||
):
|
||||
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
|
||||
chunk_region_end = region[0] - 1
|
||||
return chunk_region_start, chunk_region_end
|
||||
|
||||
@@ -164,25 +152,16 @@ class SearchChunk(object):
|
||||
for start_node, start_trace in start_traces.items():
|
||||
for start_dim, _ in enumerate(start_trace["indice"]):
|
||||
# dim size cannot be 1
|
||||
if (
|
||||
get_node_shape(end_node)[end_dim] == 1
|
||||
or get_node_shape(start_node)[start_dim] == 1
|
||||
):
|
||||
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
|
||||
continue
|
||||
# check index source align
|
||||
if not self.trace_flow.check_index_source(
|
||||
start_dim, start_node, start_idx, end_dim, end_node
|
||||
):
|
||||
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
|
||||
continue
|
||||
# check index copmute
|
||||
if not self.trace_flow.check_index_compute(
|
||||
start_idx, end_dim, end_node, end_idx
|
||||
):
|
||||
if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
|
||||
continue
|
||||
# flow search
|
||||
chunk_info = self.trace_flow.flow_search(
|
||||
start_idx, start_dim, end_idx, end_dim
|
||||
)
|
||||
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
|
||||
if chunk_info is None:
|
||||
continue
|
||||
# check index copmute
|
||||
@@ -191,9 +170,7 @@ class SearchChunk(object):
|
||||
chunk_infos.append(chunk_info)
|
||||
return chunk_infos
|
||||
|
||||
def _search_possible_chunk_regions(
|
||||
self, max_chunk_region: Tuple, peak_node: Node
|
||||
) -> List:
|
||||
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
|
||||
"""
|
||||
Search every possible region within the max chunk region.
|
||||
|
||||
@@ -206,28 +183,23 @@ class SearchChunk(object):
|
||||
"""
|
||||
possible_chunk_region = []
|
||||
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
|
||||
input_trace = [] # trace of a node's input nodes
|
||||
input_trace = [] # trace of a node's input nodes
|
||||
for _, n in enumerate(self.trace_indice.node_list):
|
||||
cur_trace = {}
|
||||
for arg in n.args:
|
||||
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
|
||||
arg
|
||||
):
|
||||
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
|
||||
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
|
||||
input_trace.append(cur_trace)
|
||||
|
||||
for start_idx in range(max_chunk_region[0], peak_node + 1):
|
||||
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
||||
# skip non compute nodes
|
||||
if is_non_compute_node(
|
||||
self.trace_indice.node_list[start_idx]
|
||||
) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
|
||||
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
|
||||
self.trace_indice.node_list[end_idx]):
|
||||
continue
|
||||
|
||||
# select free dim
|
||||
chunk_info = self._find_chunk_info(
|
||||
input_trace, output_trace, start_idx, end_idx
|
||||
)
|
||||
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
|
||||
if len(chunk_info) > 0:
|
||||
possible_chunk_region.extend(chunk_info)
|
||||
return possible_chunk_region
|
||||
@@ -256,17 +228,12 @@ class SearchChunk(object):
|
||||
best_chunk_region (Dict)
|
||||
"""
|
||||
peak_node = self._find_peak_node(mem_peak)
|
||||
max_chunk_region = self._search_max_chunk_region(
|
||||
active_node, peak_node, chunk_infos
|
||||
)
|
||||
max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
|
||||
if max_chunk_region == None:
|
||||
return None
|
||||
possible_chunk_regions = self._search_possible_chunk_regions(
|
||||
max_chunk_region, peak_node
|
||||
)
|
||||
best_chunk_region = self.select_chunk._select_best_chunk_region(
|
||||
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||
)
|
||||
possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
|
||||
best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
|
||||
max_chunk_region, mem_peak)
|
||||
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
|
||||
return best_chunk_region
|
||||
|
||||
@@ -291,9 +258,7 @@ class SearchChunk(object):
|
||||
init_mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.trace_indice.node_list
|
||||
)
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
|
||||
mem_peak = init_mem_peak
|
||||
|
||||
while True:
|
||||
@@ -306,14 +271,10 @@ class SearchChunk(object):
|
||||
mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.trace_indice.node_list, chunk_infos
|
||||
)
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
|
||||
if self._stop_search(init_mem_peak, mem_peak):
|
||||
break
|
||||
if self.print_mem:
|
||||
self.print_mem = False
|
||||
self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.trace_indice.node_list, chunk_infos, print_mem=True
|
||||
)
|
||||
self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
|
||||
return chunk_infos
|
||||
|
Reference in New Issue
Block a user