rename trace_index to trace_indice

This commit is contained in:
oahzxl
2023-01-09 17:25:13 +08:00
parent 065f0b4c27
commit 0ea903b94e
6 changed files with 74 additions and 74 deletions

View File

@@ -1,19 +1,19 @@
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_index import TraceIndex
from .trace_indice import TraceIndice
from .utils import is_non_compute_node
class SelectChunk(object):
def __init__(
self,
trace_index: TraceIndex,
trace_indice: TraceIndice,
estimate_memory: EstimateMemory,
reorder_graph: ReorderGraph,
max_memory=None,
):
self.index_tracer = trace_index
self.memory_estimator = estimate_memory
self.trace_indice = trace_indice
self.estimate_memory = estimate_memory
self.reorder_graph = reorder_graph
if max_memory is not None:
self.stratge = "fit_memory"
@@ -68,10 +68,10 @@ class SelectChunk(object):
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.index_tracer.node_list, cur_region
self.trace_indice.node_list, cur_region
)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos
)[0]
cur_chunk_region_peak = cur_mem_peak[
@@ -113,7 +113,7 @@ class SelectChunk(object):
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(
@@ -139,7 +139,7 @@ class SelectChunk(object):
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(
@@ -153,7 +153,7 @@ class SelectChunk(object):
def _get_compute_node_num(self, start, end):
count = 0
for i in self.index_tracer.node_list[start : end + 1]:
for i in self.trace_indice.node_list[start : end + 1]:
if not is_non_compute_node(i):
count += 1
return count
@@ -178,10 +178,10 @@ class SelectChunk(object):
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.index_tracer.node_list, cur_region
self.trace_indice.node_list, cur_region
)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos
)[0]
cur_chunk_region_peak = cur_mem_peak[