mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
rename trace_index to trace_indice
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from .trace_index import TraceIndex
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import (
|
||||
find_chunk_all_input_nodes,
|
||||
find_chunk_compute_input_and_output_nodes,
|
||||
@@ -10,8 +10,8 @@ from .utils import (
|
||||
|
||||
|
||||
class TraceFlow(object):
|
||||
def __init__(self, trace_index: TraceIndex) -> None:
|
||||
self.trace_index = trace_index
|
||||
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||
self.trace_indice = trace_indice
|
||||
|
||||
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
|
||||
"""
|
||||
@@ -25,8 +25,8 @@ class TraceFlow(object):
|
||||
Returns:
|
||||
bool: True if check pass
|
||||
"""
|
||||
start_node_idx = find_idx_by_name(start_node.name, self.trace_index.node_list)
|
||||
end_node_trace = self.trace_index._find_trace_from_node(end_node)
|
||||
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
|
||||
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||
sorted_source = sorted(
|
||||
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
|
||||
@@ -51,24 +51,24 @@ class TraceFlow(object):
|
||||
Returns:
|
||||
bool: True if check pass
|
||||
"""
|
||||
end_node_trace = self.trace_index._find_trace_from_node(end_node)
|
||||
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||
end_node_compute = end_node_trace["compute"][end_dim]
|
||||
if any(start_idx <= i <= end_idx for i in end_node_compute):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||
node_from_source = self.trace_index._find_source_trace_from_node(node_from)
|
||||
node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
|
||||
dim_source = node_from_source[node_from_dim]
|
||||
node_to_idx = find_idx_by_name(node_to.name, self.trace_index.node_list)
|
||||
node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
|
||||
for k, v in dim_source.items():
|
||||
if k == node_to_idx:
|
||||
return v
|
||||
return None
|
||||
|
||||
def _find_inherit_dim(self, input_node, input_dim, node):
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_index.node_list)
|
||||
node_trace_source = self.trace_index._find_source_trace_from_node(node)
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(get_node_shape(node))):
|
||||
if (
|
||||
input_node_idx in node_trace_source[node_dim]
|
||||
@@ -82,19 +82,19 @@ class TraceFlow(object):
|
||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||
inherit_dim = self._find_inherit_dim(
|
||||
input_node, v, self.trace_index.node_list[k]
|
||||
input_node, v, self.trace_indice.node_list[k]
|
||||
)
|
||||
if inherit_dim:
|
||||
input_dim_after_node[k] = inherit_dim
|
||||
|
||||
for node in self.trace_index.node_list[
|
||||
for node in self.trace_indice.node_list[
|
||||
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
||||
]:
|
||||
if is_non_compute_node_except_placeholder(node):
|
||||
continue
|
||||
count = 0
|
||||
duplicate_dims = []
|
||||
node_trace_source = self.trace_index._find_source_trace_from_node(node)
|
||||
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(get_node_shape(node))):
|
||||
duplicate_dim = []
|
||||
duplicate_flag = False
|
||||
@@ -130,7 +130,7 @@ class TraceFlow(object):
|
||||
all_node_info,
|
||||
next_node_list,
|
||||
):
|
||||
arg_idx = find_idx_by_name(arg_node.name, self.trace_index.node_list)
|
||||
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
|
||||
# arg in chunk range or be inputs
|
||||
if not (start_idx <= arg_idx < end_idx):
|
||||
return True
|
||||
@@ -171,7 +171,7 @@ class TraceFlow(object):
|
||||
|
||||
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||
cur_node_list = [
|
||||
self.trace_index.node_list[end_idx]
|
||||
self.trace_indice.node_list[end_idx]
|
||||
] # start from the last node
|
||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||
|
||||
@@ -183,10 +183,10 @@ class TraceFlow(object):
|
||||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||
if cur_node_chunk_dim:
|
||||
cur_node_compute = self.trace_index._find_compute_trace_from_node(
|
||||
cur_node_compute = self.trace_indice._find_compute_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
cur_node_source = self.trace_index._find_source_trace_from_node(
|
||||
cur_node_source = self.trace_indice._find_source_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
else:
|
||||
@@ -220,7 +220,7 @@ class TraceFlow(object):
|
||||
if not (
|
||||
start_idx
|
||||
<= find_idx_by_name(
|
||||
arg.name, self.trace_index.node_list
|
||||
arg.name, self.trace_indice.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
@@ -250,16 +250,16 @@ class TraceFlow(object):
|
||||
for input_node in inputs:
|
||||
input_dict = {}
|
||||
input_node_idx = find_idx_by_name(
|
||||
input_node.name, self.trace_index.node_list
|
||||
input_node.name, self.trace_indice.node_list
|
||||
)
|
||||
for user in input_node.users.keys():
|
||||
if is_non_compute_node(user):
|
||||
continue
|
||||
user_idx = find_idx_by_name(user.name, self.trace_index.node_list)
|
||||
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
|
||||
if start_idx <= user_idx <= end_idx:
|
||||
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||
if chunk_dim is not None:
|
||||
user_source = self.trace_index._find_source_trace_from_node(
|
||||
user_source = self.trace_indice._find_source_trace_from_node(
|
||||
user
|
||||
)[chunk_dim]
|
||||
if input_node_idx in user_source:
|
||||
@@ -282,7 +282,7 @@ class TraceFlow(object):
|
||||
if node_info["chunk_dim"] is None:
|
||||
maybe_prepose_nodes.append(node)
|
||||
maybe_prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list),
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
|
||||
reverse=True,
|
||||
) # from last node to first node
|
||||
prepose_nodes = []
|
||||
@@ -308,7 +308,7 @@ class TraceFlow(object):
|
||||
if not (
|
||||
start_idx
|
||||
<= find_idx_by_name(
|
||||
cur_prepose_node_arg.name, self.trace_index.node_list
|
||||
cur_prepose_node_arg.name, self.trace_indice.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
@@ -336,14 +336,14 @@ class TraceFlow(object):
|
||||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_index.node_list)
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
|
||||
)
|
||||
|
||||
return prepose_nodes
|
||||
|
||||
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
||||
# we need to log input nodes to avoid deleteing them in the loop
|
||||
chunk_node_list = self.trace_index.node_list[start_idx : end_idx + 1]
|
||||
chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||
# also need to get some prepose node's arg out of non_chunk_inputs
|
||||
for n in chunk_info["args"]["prepose_nodes"]:
|
||||
chunk_node_list.remove(n)
|
||||
@@ -355,7 +355,7 @@ class TraceFlow(object):
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||
self.trace_index.node_list[start_idx : end_idx + 1]
|
||||
self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
# only single ouput
|
||||
if len(outputs) > 1:
|
||||
@@ -403,10 +403,10 @@ class TraceFlow(object):
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[
|
||||
chunk_info["outputs_dim"]
|
||||
]
|
||||
for node in self.trace_index.node_list[chunk_region[0] : chunk_region[1] + 1]:
|
||||
for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
|
||||
if any(i in node.name for i in ["reshape", "view"]):
|
||||
reshape_args = node.args[1:]
|
||||
reshape_log = self.trace_index.idx_view_list[node]
|
||||
reshape_log = self.trace_indice.idx_view_list[node]
|
||||
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
|
||||
reshape_size[node.name] = {}
|
||||
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
|
||||
|
Reference in New Issue
Block a user