separate trace flow

This commit is contained in:
oahzxl
2023-01-06 17:24:23 +08:00
parent 4748967fb1
commit a6cdbf9161
6 changed files with 447 additions and 424 deletions

View File

@@ -1,8 +1,10 @@
import copy
from .select_chunk import SelectChunk
from .trace_index import TraceIndex, ReorderGraph
from .trace_index import TraceIndex
from .reorder_graph import ReorderGraph
from .estiamte_memory import EstimateMemory
from .trace_flow import TraceFlow
from .utils import (
get_node_shape,
is_non_compute_node,
@@ -14,12 +16,13 @@ class SearchChunk(object):
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm
self.print_mem = print_mem
self.index_tracer = TraceIndex(list(gm.graph.nodes))
self.index_tracer.trace_index()
self.reorder_graph = ReorderGraph(self.index_tracer)
self.memory_estimator = EstimateMemory()
self.chunk_selector = SelectChunk(
self.index_tracer, self.memory_estimator, self.reorder_graph, max_memory=max_memory
self.trace_index = TraceIndex(list(gm.graph.nodes))
self.trace_index.trace_index()
self.trace_flow = TraceFlow(self.trace_index)
self.reorder_graph = ReorderGraph(self.trace_index)
self.estimate_memory = EstimateMemory()
self.select_chunk = SelectChunk(
self.trace_index, self.estimate_memory, self.reorder_graph, max_memory=max_memory
)
def _find_peak_node(self, mem_peak):
@@ -29,7 +32,7 @@ class SearchChunk(object):
def _get_free_var(self):
free_var_idx = []
for idx, n in enumerate(self.index_tracer.node_list):
for idx, n in enumerate(self.trace_index.node_list):
if n.op == "placeholder":
free_var_idx.append(idx)
return free_var_idx
@@ -99,7 +102,7 @@ class SearchChunk(object):
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx]
end_node = self.index_tracer.node_list[end_idx]
end_node = self.trace_index.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["idx"]):
if len(start_traces) > 1:
@@ -113,46 +116,46 @@ class SearchChunk(object):
):
continue
# check index source align
if not self.index_tracer.check_index_source(
if not self.trace_flow.check_index_source(
start_dim, start_node, start_idx, end_dim, end_node
):
continue
# check index compute
if not self.index_tracer.check_index_compute(
if not self.trace_flow.check_index_compute(
start_idx, end_dim, end_node, end_idx
):
continue
# flow search
chunk_info = self.index_tracer.flow_search(
chunk_info = self.trace_flow.flow_search(
start_idx, start_dim, end_idx, end_dim
)
if chunk_info is None:
continue
# check index duplicate
if not self.index_tracer.check_index_duplicate(chunk_info):
if not self.trace_flow.check_index_duplicate(chunk_info):
continue
chunk_infos.append(chunk_info)
return chunk_infos
def _search_possible_chunk_regions(self, max_chunk_region, peak_node):
possible_chunk_region = []
output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.index_tracer.node_list):
for _, n in enumerate(self.trace_index.node_list):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
arg
):
cur_trace[arg] = self.index_tracer._find_trace_from_node(arg)
cur_trace[arg] = self.trace_index._find_trace_from_node(arg)
input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(
self.index_tracer.node_list[start_idx]
) or is_non_compute_node(self.index_tracer.node_list[end_idx]):
self.trace_index.node_list[start_idx]
) or is_non_compute_node(self.trace_index.node_list[end_idx]):
continue
# select free dim
@@ -173,7 +176,7 @@ class SearchChunk(object):
possible_chunk_regions = self._search_possible_chunk_regions(
max_chunk_region, peak_node
)
best_chunk_region = self.chunk_selector._select_best_chunk_region(
best_chunk_region = self.select_chunk._select_best_chunk_region(
possible_chunk_regions, chunk_regions, peak_node, max_chunk_region, mem_peak
)
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
@@ -191,8 +194,8 @@ class SearchChunk(object):
init_mem_peak,
_,
active_node,
) = self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_index.node_list
)
mem_peak = init_mem_peak
@@ -206,14 +209,14 @@ class SearchChunk(object):
mem_peak,
_,
active_node,
) = self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list, chunk_infos
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_index.node_list, chunk_infos
)
if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:
self.print_mem = False
self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list, chunk_infos, print_mem=True
self.estimate_memory.estimate_chunk_inference_mem(
self.trace_index.node_list, chunk_infos, print_mem=True
)
return chunk_infos