rename trace_index to trace_indice

This commit is contained in:
oahzxl
2023-01-09 17:25:13 +08:00
parent 065f0b4c27
commit 0ea903b94e
6 changed files with 74 additions and 74 deletions

View File

@@ -7,7 +7,7 @@ from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_index import TraceIndex
from .trace_indice import TraceIndice
from .utils import (
get_node_shape,
is_non_compute_node,
@@ -47,13 +47,13 @@ class SearchChunk(object):
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm
self.print_mem = print_mem
self.trace_index = TraceIndex(list(gm.graph.nodes))
self.trace_index.trace_index()
self.trace_flow = TraceFlow(self.trace_index)
self.reorder_graph = ReorderGraph(self.trace_index)
self.trace_indice = TraceIndice(list(gm.graph.nodes))
self.trace_indice.trace_index()
self.trace_flow = TraceFlow(self.trace_indice)
self.reorder_graph = ReorderGraph(self.trace_indice)
self.estimate_memory = EstimateMemory()
self.select_chunk = SelectChunk(
self.trace_index,
self.trace_indice,
self.estimate_memory,
self.reorder_graph,
max_memory=max_memory,
@@ -72,7 +72,7 @@ class SearchChunk(object):
free_var_idx (List): all indexs of free vars
"""
free_var_idx = []
for idx, n in enumerate(self.trace_index.node_list):
for idx, n in enumerate(self.trace_indice.node_list):
if n.op == "placeholder":
free_var_idx.append(idx)
return free_var_idx
@@ -156,7 +156,7 @@ class SearchChunk(object):
"""
start_traces = input_trace[start_idx]
end_trace = output_trace[end_idx]
end_node = self.trace_index.node_list[end_idx]
end_node = self.trace_indice.node_list[end_idx]
chunk_infos = []
for end_dim, _ in enumerate(end_trace["idx"]):
if len(start_traces) > 1:
@@ -205,23 +205,23 @@ class SearchChunk(object):
possible_chunk_region (List)
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
output_trace = copy.deepcopy(self.trace_indice.idx_trace_list)
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.trace_index.node_list):
for _, n in enumerate(self.trace_indice.node_list):
cur_trace = {}
for arg in n.args:
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
arg
):
cur_trace[arg] = self.trace_index._find_trace_from_node(arg)
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
input_trace.append(cur_trace)
for start_idx in range(max_chunk_region[0], peak_node + 1):
for end_idx in range(peak_node, max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(
self.trace_index.node_list[start_idx]
) or is_non_compute_node(self.trace_index.node_list[end_idx]):
self.trace_indice.node_list[start_idx]
) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
continue
# select free dim
@@ -292,7 +292,7 @@ class SearchChunk(object):
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_index.node_list
self.trace_indice.node_list
)
mem_peak = init_mem_peak
@@ -307,13 +307,13 @@ class SearchChunk(object):
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(
self.trace_index.node_list, chunk_infos
self.trace_indice.node_list, chunk_infos
)
if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(
self.trace_index.node_list, chunk_infos, print_mem=True
self.trace_indice.node_list, chunk_infos, print_mem=True
)
return chunk_infos