mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[autochunk] support transformer (#2526)
This commit is contained in:
@@ -8,7 +8,13 @@ from .reorder_graph import ReorderGraph
|
||||
from .select_chunk import SelectChunk
|
||||
from .trace_flow import TraceFlow
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
|
||||
from .utils import (
|
||||
find_chunk_compute_input_and_output_nodes,
|
||||
get_logger,
|
||||
get_node_shape,
|
||||
is_non_compute_node,
|
||||
is_non_compute_node_except_placeholder,
|
||||
)
|
||||
|
||||
|
||||
class SearchChunk(object):
|
||||
@@ -114,6 +120,12 @@ class SearchChunk(object):
|
||||
chunk_region_start (int)
|
||||
chunk_region_end (int)
|
||||
"""
|
||||
# check if peak node already in chunkinfo
|
||||
if chunk_regions is not None:
|
||||
for i in chunk_regions:
|
||||
if i["region"][0] < peak_node_idx <= i["region"][1]:
|
||||
return None
|
||||
|
||||
free_vars = self._get_free_var_idx()
|
||||
free_var_num = len(free_vars)
|
||||
active_node_num = [len(i) for i in active_node]
|
||||
@@ -152,55 +164,6 @@ class SearchChunk(object):
|
||||
chunk_region_end = region[0] - 1
|
||||
return chunk_region_start, chunk_region_end
|
||||
|
||||
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
|
||||
"""
|
||||
Find chunk info for a region.
|
||||
|
||||
We are given the region start and region end, and need to find out all chunk info for it.
|
||||
We first loop every dim of start node and end node, to see if we can find dim pair,
|
||||
which is linked in a flow and not computed.
|
||||
If found, we then search flow in the whole region to find out all chunk infos.
|
||||
|
||||
Args:
|
||||
input_trace (List): node's input trace in region
|
||||
output_trace (List): node's output trace in region
|
||||
start_idx (int): region start node index
|
||||
end_idx (int): region end node index
|
||||
|
||||
Returns:
|
||||
chunk_infos: possible regions found
|
||||
"""
|
||||
start_traces = input_trace[start_idx]
|
||||
end_trace = output_trace[end_idx]
|
||||
end_node = self.trace_indice.node_list[end_idx]
|
||||
chunk_infos = []
|
||||
for end_dim, _ in enumerate(end_trace["indice"]):
|
||||
if len(start_traces) > 1:
|
||||
continue
|
||||
for start_node, start_trace in start_traces.items():
|
||||
for start_dim, _ in enumerate(start_trace["indice"]):
|
||||
# dim size cannot be 1
|
||||
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
|
||||
continue
|
||||
# must have users
|
||||
if len(end_node.users) == 0:
|
||||
continue
|
||||
# check index source align
|
||||
if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
|
||||
continue
|
||||
# check index copmute
|
||||
if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
|
||||
continue
|
||||
# flow search
|
||||
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
|
||||
if chunk_info is None:
|
||||
continue
|
||||
# check index copmute
|
||||
if not self.trace_flow.check_index_duplicate(chunk_info):
|
||||
continue
|
||||
chunk_infos.append(chunk_info)
|
||||
return chunk_infos
|
||||
|
||||
def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
|
||||
"""
|
||||
Search every possible region within the max chunk region.
|
||||
@@ -228,9 +191,8 @@ class SearchChunk(object):
|
||||
if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
|
||||
self.trace_indice.node_list[end_idx]):
|
||||
continue
|
||||
|
||||
# select free dim
|
||||
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
|
||||
chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
|
||||
if len(chunk_info) > 0:
|
||||
possible_chunk_region.extend(chunk_info)
|
||||
return possible_chunk_region
|
||||
|
Reference in New Issue
Block a user