mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
rename trace_index to trace_indice
This commit is contained in:
@@ -1,19 +1,19 @@
|
||||
from .estimate_memory import EstimateMemory
|
||||
from .reorder_graph import ReorderGraph
|
||||
from .trace_index import TraceIndex
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import is_non_compute_node
|
||||
|
||||
|
||||
class SelectChunk(object):
|
||||
def __init__(
|
||||
self,
|
||||
trace_index: TraceIndex,
|
||||
trace_indice: TraceIndice,
|
||||
estimate_memory: EstimateMemory,
|
||||
reorder_graph: ReorderGraph,
|
||||
max_memory=None,
|
||||
):
|
||||
self.index_tracer = trace_index
|
||||
self.memory_estimator = estimate_memory
|
||||
self.trace_indice = trace_indice
|
||||
self.estimate_memory = estimate_memory
|
||||
self.reorder_graph = reorder_graph
|
||||
if max_memory is not None:
|
||||
self.stratge = "fit_memory"
|
||||
@@ -68,10 +68,10 @@ class SelectChunk(object):
|
||||
for region in possible_chunk_regions:
|
||||
cur_region = region.copy()
|
||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||
self.index_tracer.node_list, cur_region
|
||||
self.trace_indice.node_list, cur_region
|
||||
)
|
||||
cur_chunk_infos = chunk_infos + [cur_region]
|
||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
cur_node_list, cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_region_peak = cur_mem_peak[
|
||||
@@ -113,7 +113,7 @@ class SelectChunk(object):
|
||||
chunk_size *= 2
|
||||
reorder_chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
@@ -139,7 +139,7 @@ class SelectChunk(object):
|
||||
mid = int((left + right) / 2 + 0.5)
|
||||
chunk_info["chunk_size"] = mid
|
||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
@@ -153,7 +153,7 @@ class SelectChunk(object):
|
||||
|
||||
def _get_compute_node_num(self, start, end):
|
||||
count = 0
|
||||
for i in self.index_tracer.node_list[start : end + 1]:
|
||||
for i in self.trace_indice.node_list[start : end + 1]:
|
||||
if not is_non_compute_node(i):
|
||||
count += 1
|
||||
return count
|
||||
@@ -178,10 +178,10 @@ class SelectChunk(object):
|
||||
for region in possible_chunk_regions:
|
||||
cur_region = region.copy()
|
||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||
self.index_tracer.node_list, cur_region
|
||||
self.trace_indice.node_list, cur_region
|
||||
)
|
||||
cur_chunk_infos = chunk_infos + [cur_region]
|
||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
cur_node_list, cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_region_peak = cur_mem_peak[
|
||||
|
Reference in New Issue
Block a user