mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
rename trace_index to trace_indice
This commit is contained in:
@@ -7,7 +7,7 @@ from .estimate_memory import EstimateMemory
|
||||
from .reorder_graph import ReorderGraph
|
||||
from .select_chunk import SelectChunk
|
||||
from .trace_flow import TraceFlow
|
||||
from .trace_index import TraceIndex
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import (
|
||||
get_node_shape,
|
||||
is_non_compute_node,
|
||||
@@ -47,13 +47,13 @@ class SearchChunk(object):
|
||||
def __init__(self, gm, max_memory=None, print_mem=False) -> None:
|
||||
self.gm = gm
|
||||
self.print_mem = print_mem
|
||||
self.trace_index = TraceIndex(list(gm.graph.nodes))
|
||||
self.trace_index.trace_index()
|
||||
self.trace_flow = TraceFlow(self.trace_index)
|
||||
self.reorder_graph = ReorderGraph(self.trace_index)
|
||||
self.trace_indice = TraceIndice(list(gm.graph.nodes))
|
||||
self.trace_indice.trace_index()
|
||||
self.trace_flow = TraceFlow(self.trace_indice)
|
||||
self.reorder_graph = ReorderGraph(self.trace_indice)
|
||||
self.estimate_memory = EstimateMemory()
|
||||
self.select_chunk = SelectChunk(
|
||||
self.trace_index,
|
||||
self.trace_indice,
|
||||
self.estimate_memory,
|
||||
self.reorder_graph,
|
||||
max_memory=max_memory,
|
||||
@@ -72,7 +72,7 @@ class SearchChunk(object):
|
||||
free_var_idx (List): all indexs of free vars
|
||||
"""
|
||||
free_var_idx = []
|
||||
for idx, n in enumerate(self.trace_index.node_list):
|
||||
for idx, n in enumerate(self.trace_indice.node_list):
|
||||
if n.op == "placeholder":
|
||||
free_var_idx.append(idx)
|
||||
return free_var_idx
|
||||
@@ -156,7 +156,7 @@ class SearchChunk(object):
|
||||
"""
|
||||
start_traces = input_trace[start_idx]
|
||||
end_trace = output_trace[end_idx]
|
||||
end_node = self.trace_index.node_list[end_idx]
|
||||
end_node = self.trace_indice.node_list[end_idx]
|
||||
chunk_infos = []
|
||||
for end_dim, _ in enumerate(end_trace["idx"]):
|
||||
if len(start_traces) > 1:
|
||||
@@ -205,23 +205,23 @@ class SearchChunk(object):
|
||||
possible_chunk_region (List)
|
||||
"""
|
||||
possible_chunk_region = []
|
||||
output_trace = copy.deepcopy(self.trace_index.idx_trace_list)
|
||||
output_trace = copy.deepcopy(self.trace_indice.idx_trace_list)
|
||||
input_trace = [] # trace of a node's input nodes
|
||||
for _, n in enumerate(self.trace_index.node_list):
|
||||
for _, n in enumerate(self.trace_indice.node_list):
|
||||
cur_trace = {}
|
||||
for arg in n.args:
|
||||
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
|
||||
arg
|
||||
):
|
||||
cur_trace[arg] = self.trace_index._find_trace_from_node(arg)
|
||||
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
|
||||
input_trace.append(cur_trace)
|
||||
|
||||
for start_idx in range(max_chunk_region[0], peak_node + 1):
|
||||
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
||||
# skip non compute nodes
|
||||
if is_non_compute_node(
|
||||
self.trace_index.node_list[start_idx]
|
||||
) or is_non_compute_node(self.trace_index.node_list[end_idx]):
|
||||
self.trace_indice.node_list[start_idx]
|
||||
) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
|
||||
continue
|
||||
|
||||
# select free dim
|
||||
@@ -292,7 +292,7 @@ class SearchChunk(object):
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.trace_index.node_list
|
||||
self.trace_indice.node_list
|
||||
)
|
||||
mem_peak = init_mem_peak
|
||||
|
||||
@@ -307,13 +307,13 @@ class SearchChunk(object):
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.trace_index.node_list, chunk_infos
|
||||
self.trace_indice.node_list, chunk_infos
|
||||
)
|
||||
if self._stop_search(init_mem_peak, mem_peak):
|
||||
break
|
||||
if self.print_mem:
|
||||
self.print_mem = False
|
||||
self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.trace_index.node_list, chunk_infos, print_mem=True
|
||||
self.trace_indice.node_list, chunk_infos, print_mem=True
|
||||
)
|
||||
return chunk_infos
|
||||
|
Reference in New Issue
Block a user