mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-18 17:31:53 +00:00
remove flow tracer
This commit is contained in:
parent
4d89525fc2
commit
4f5e105af3
171
chunk_codegen.py
171
chunk_codegen.py
@ -67,7 +67,7 @@ def _is_non_compute_node_except_placeholder_output(node):
|
|||||||
class IndexTracer(object):
|
class IndexTracer(object):
|
||||||
def __init__(self, gm) -> None:
|
def __init__(self, gm) -> None:
|
||||||
self.gm = gm
|
self.gm = gm
|
||||||
self.nodes_list = list(gm.graph.nodes)
|
self.node_list = list(gm.graph.nodes)
|
||||||
self.idx_trace_list = self._init_idx_trace_list()
|
self.idx_trace_list = self._init_idx_trace_list()
|
||||||
self.idx_trace_equal = []
|
self.idx_trace_equal = []
|
||||||
self.idx_view_list = []
|
self.idx_view_list = []
|
||||||
@ -75,7 +75,7 @@ class IndexTracer(object):
|
|||||||
|
|
||||||
def _init_idx_trace_list(self):
|
def _init_idx_trace_list(self):
|
||||||
idx_trace_list = []
|
idx_trace_list = []
|
||||||
for n in self.nodes_list:
|
for n in self.node_list:
|
||||||
if _get_node_shape(n) != None:
|
if _get_node_shape(n) != None:
|
||||||
cur_trace = {
|
cur_trace = {
|
||||||
"idx": [None for _ in range(len(_get_node_shape(n)))],
|
"idx": [None for _ in range(len(_get_node_shape(n)))],
|
||||||
@ -136,7 +136,7 @@ class IndexTracer(object):
|
|||||||
node_from_trace = self._find_trace_from_node(node_from)
|
node_from_trace = self._find_trace_from_node(node_from)
|
||||||
node_to_dim = self._transform_index(node_to, node_to_dim)
|
node_to_dim = self._transform_index(node_to, node_to_dim)
|
||||||
node_to_trace = self._find_trace_from_node(node_to)
|
node_to_trace = self._find_trace_from_node(node_to)
|
||||||
node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
|
node_from_idx = _find_idx_by_name(node_from.name, self.node_list)
|
||||||
if init:
|
if init:
|
||||||
node_to_trace["source"][node_to_dim] = {}
|
node_to_trace["source"][node_to_dim] = {}
|
||||||
# add dim to cur new source
|
# add dim to cur new source
|
||||||
@ -210,7 +210,7 @@ class IndexTracer(object):
|
|||||||
idx (list): idx of the node
|
idx (list): idx of the node
|
||||||
compute (list): computed idx of the node.
|
compute (list): computed idx of the node.
|
||||||
"""
|
"""
|
||||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
node_idx = _find_idx_by_name(node.name, self.node_list)
|
||||||
node_dict = self.idx_trace_list[node_idx]
|
node_dict = self.idx_trace_list[node_idx]
|
||||||
return node_dict
|
return node_dict
|
||||||
|
|
||||||
@ -224,7 +224,7 @@ class IndexTracer(object):
|
|||||||
idx (list): idx of the node
|
idx (list): idx of the node
|
||||||
compute (list): computed idx of the node.
|
compute (list): computed idx of the node.
|
||||||
"""
|
"""
|
||||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
node_idx = _find_idx_by_name(node.name, self.node_list)
|
||||||
node_dict = self.idx_trace_list[node_idx]
|
node_dict = self.idx_trace_list[node_idx]
|
||||||
return node_dict["source"]
|
return node_dict["source"]
|
||||||
|
|
||||||
@ -237,7 +237,7 @@ class IndexTracer(object):
|
|||||||
Returns:
|
Returns:
|
||||||
idx (list): idx of the node
|
idx (list): idx of the node
|
||||||
"""
|
"""
|
||||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
node_idx = _find_idx_by_name(node.name, self.node_list)
|
||||||
return self.idx_trace_list[node_idx]["idx"]
|
return self.idx_trace_list[node_idx]["idx"]
|
||||||
|
|
||||||
def _find_compute_trace_from_node(self, node):
|
def _find_compute_trace_from_node(self, node):
|
||||||
@ -249,7 +249,7 @@ class IndexTracer(object):
|
|||||||
Returns:
|
Returns:
|
||||||
compute (list): computed idx of the node.
|
compute (list): computed idx of the node.
|
||||||
"""
|
"""
|
||||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
node_idx = _find_idx_by_name(node.name, self.node_list)
|
||||||
return self.idx_trace_list[node_idx]["compute"]
|
return self.idx_trace_list[node_idx]["compute"]
|
||||||
|
|
||||||
def _assign_index_as_input(self, node, node_idx, input_node=None):
|
def _assign_index_as_input(self, node, node_idx, input_node=None):
|
||||||
@ -262,7 +262,7 @@ class IndexTracer(object):
|
|||||||
"""
|
"""
|
||||||
if input_node == None:
|
if input_node == None:
|
||||||
input_node = node.args[0]
|
input_node = node.args[0]
|
||||||
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
|
input_node_idx = _find_idx_by_name(input_node.name, self.node_list)
|
||||||
input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"]
|
input_node_idx_trace = self.idx_trace_list[input_node_idx]["idx"]
|
||||||
|
|
||||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||||
@ -591,7 +591,7 @@ class IndexTracer(object):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def trace_index(self):
|
def trace_index(self):
|
||||||
for idx, node in enumerate(self.nodes_list):
|
for idx, node in enumerate(self.node_list):
|
||||||
if node.op == "placeholder":
|
if node.op == "placeholder":
|
||||||
self._assign_all_index(node, idx)
|
self._assign_all_index(node, idx)
|
||||||
elif node.op == "call_method":
|
elif node.op == "call_method":
|
||||||
@ -655,7 +655,7 @@ class IndexTracer(object):
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if check pass
|
bool: True if check pass
|
||||||
"""
|
"""
|
||||||
start_node_idx = _find_idx_by_name(start_node.name, self.nodes_list)
|
start_node_idx = _find_idx_by_name(start_node.name, self.node_list)
|
||||||
end_node_trace = self._find_trace_from_node(end_node)
|
end_node_trace = self._find_trace_from_node(end_node)
|
||||||
end_node_trace_source = end_node_trace["source"][end_dim]
|
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||||
sorted_source = sorted(
|
sorted_source = sorted(
|
||||||
@ -690,14 +690,14 @@ class IndexTracer(object):
|
|||||||
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||||
node_from_source = self._find_source_trace_from_node(node_from)
|
node_from_source = self._find_source_trace_from_node(node_from)
|
||||||
dim_source = node_from_source[node_from_dim]
|
dim_source = node_from_source[node_from_dim]
|
||||||
node_to_idx = _find_idx_by_name(node_to.name, self.nodes_list)
|
node_to_idx = _find_idx_by_name(node_to.name, self.node_list)
|
||||||
for k, v in dim_source.items():
|
for k, v in dim_source.items():
|
||||||
if k == node_to_idx:
|
if k == node_to_idx:
|
||||||
return v
|
return v
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _find_inherit_dim(self, input_node, input_dim, node):
|
def _find_inherit_dim(self, input_node, input_dim, node):
|
||||||
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
|
input_node_idx = _find_idx_by_name(input_node.name, self.node_list)
|
||||||
node_trace_source = self._find_source_trace_from_node(node)
|
node_trace_source = self._find_source_trace_from_node(node)
|
||||||
for node_dim in range(len(_get_node_shape(node))):
|
for node_dim in range(len(_get_node_shape(node))):
|
||||||
if (
|
if (
|
||||||
@ -711,11 +711,11 @@ class IndexTracer(object):
|
|||||||
input_dim_after_node = {}
|
input_dim_after_node = {}
|
||||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||||
inherit_dim = self._find_inherit_dim(input_node, v, self.nodes_list[k])
|
inherit_dim = self._find_inherit_dim(input_node, v, self.node_list[k])
|
||||||
if inherit_dim:
|
if inherit_dim:
|
||||||
input_dim_after_node[k] = inherit_dim
|
input_dim_after_node[k] = inherit_dim
|
||||||
|
|
||||||
for node in self.nodes_list[
|
for node in self.node_list[
|
||||||
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
||||||
]:
|
]:
|
||||||
if _is_non_compute_node_except_placeholder(node):
|
if _is_non_compute_node_except_placeholder(node):
|
||||||
@ -746,124 +746,11 @@ class IndexTracer(object):
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class FlowTracer(object):
|
|
||||||
def __init__(self, gm) -> None:
|
|
||||||
self.gm = gm
|
|
||||||
self.node_list = list(gm.graph.nodes)
|
|
||||||
self.flow_trace = {}
|
|
||||||
|
|
||||||
def _add_trace(self, name):
|
|
||||||
self.flow_trace[name] = []
|
|
||||||
|
|
||||||
def _add_node(self, trace_name, node):
|
|
||||||
self.flow_trace[trace_name].append(
|
|
||||||
{"node": node, "inside_depend": [], "outside_depend": []}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_inside_depend(self, flow_name, node, inside_depend_node):
|
|
||||||
for i in self.flow_trace[flow_name]:
|
|
||||||
if i["node"] == node:
|
|
||||||
i["inside_depend"].append(inside_depend_node)
|
|
||||||
return
|
|
||||||
raise RuntimeError("node not found")
|
|
||||||
|
|
||||||
def _add_outside_depend(
|
|
||||||
self, flow_name, node, outside_depend_node, outside_depend_trace
|
|
||||||
):
|
|
||||||
for i in self.flow_trace[flow_name]:
|
|
||||||
if i["node"] == node:
|
|
||||||
i["outside_depend"].append({outside_depend_trace: outside_depend_node})
|
|
||||||
return
|
|
||||||
raise RuntimeError("node not found")
|
|
||||||
|
|
||||||
def _init_trace(self):
|
|
||||||
for i in self.node_list:
|
|
||||||
if i.op == "placeholder":
|
|
||||||
self._add_trace(i.name)
|
|
||||||
self._add_node(i.name, i)
|
|
||||||
|
|
||||||
def _find_flow_for_node(self, node):
|
|
||||||
if type(self.node_list[0]) != type(node):
|
|
||||||
return None
|
|
||||||
if _is_non_compute_node_except_placeholder(node):
|
|
||||||
return None
|
|
||||||
for name, trace in self.flow_trace.items():
|
|
||||||
for i in trace:
|
|
||||||
if node == i["node"]:
|
|
||||||
return name
|
|
||||||
if any(i in node.name for i in ["ones_like"]):
|
|
||||||
self._add_trace(node.name)
|
|
||||||
self._add_node(node.name, node)
|
|
||||||
return node.name
|
|
||||||
raise RuntimeError("node not found")
|
|
||||||
|
|
||||||
def _find_first_valid_flow(self, flow):
|
|
||||||
for i in flow:
|
|
||||||
if i is not None:
|
|
||||||
return i
|
|
||||||
raise RuntimeError("invalid flow")
|
|
||||||
|
|
||||||
def find_node_flow(self, node):
|
|
||||||
for name, trace in self.flow_trace.items():
|
|
||||||
for i in trace:
|
|
||||||
if node == i["node"]:
|
|
||||||
return name, i
|
|
||||||
raise RuntimeError("invalid node")
|
|
||||||
|
|
||||||
def _get_flow_mix_node(self, node):
|
|
||||||
if _is_non_compute_node(node):
|
|
||||||
return None
|
|
||||||
_, node_trace = self.find_node_flow(node)
|
|
||||||
if len(node_trace["outside_depend"]) == 0:
|
|
||||||
return None
|
|
||||||
elif len(node_trace["outside_depend"]) > 1:
|
|
||||||
raise NotImplementedError
|
|
||||||
vars = list(node_trace["outside_depend"][0].values())[0]
|
|
||||||
return vars
|
|
||||||
|
|
||||||
def _get_same_flow_node(self, node_list, node):
|
|
||||||
name, _ = self.find_node_flow(node)
|
|
||||||
result = []
|
|
||||||
for i in self.flow_trace[name]:
|
|
||||||
if i["node"] in node_list:
|
|
||||||
result.append(i["node"])
|
|
||||||
return result
|
|
||||||
|
|
||||||
def trace_flow(self):
|
|
||||||
# init trace
|
|
||||||
self._init_trace()
|
|
||||||
|
|
||||||
for node in self.node_list:
|
|
||||||
# skip if non compute node
|
|
||||||
if all(
|
|
||||||
type(arg) != type(node) or _is_non_compute_node_except_placeholder(arg)
|
|
||||||
for arg in node.args
|
|
||||||
) or _is_non_compute_node(node):
|
|
||||||
continue
|
|
||||||
|
|
||||||
node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]
|
|
||||||
|
|
||||||
node_domin_flow = self._find_first_valid_flow(node_input_flows)
|
|
||||||
self._add_node(node_domin_flow, node)
|
|
||||||
for node_input_flow, arg in zip(node_input_flows, node.args):
|
|
||||||
if node_input_flow is None:
|
|
||||||
continue
|
|
||||||
elif node_input_flow == node_domin_flow:
|
|
||||||
self._add_inside_depend(node_domin_flow, node, arg)
|
|
||||||
else:
|
|
||||||
self._add_outside_depend(
|
|
||||||
node_domin_flow, node, arg, node_input_flow
|
|
||||||
)
|
|
||||||
return self.flow_trace
|
|
||||||
|
|
||||||
def _assgin_single_node_flow(
|
def _assgin_single_node_flow(
|
||||||
self,
|
self,
|
||||||
arg_node,
|
arg_node,
|
||||||
start_idx,
|
start_idx,
|
||||||
end_idx,
|
end_idx,
|
||||||
inputs,
|
|
||||||
index_tracer,
|
|
||||||
cur_node_dim,
|
cur_node_dim,
|
||||||
cur_node_compute,
|
cur_node_compute,
|
||||||
cur_node_source,
|
cur_node_source,
|
||||||
@ -871,7 +758,7 @@ class FlowTracer(object):
|
|||||||
all_node_info,
|
all_node_info,
|
||||||
next_node_list,
|
next_node_list,
|
||||||
):
|
):
|
||||||
arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list)
|
arg_idx = _find_idx_by_name(arg_node.name, self.node_list)
|
||||||
# arg in chunk range or be inputs
|
# arg in chunk range or be inputs
|
||||||
if not (start_idx <= arg_idx < end_idx):
|
if not (start_idx <= arg_idx < end_idx):
|
||||||
return True
|
return True
|
||||||
@ -911,7 +798,7 @@ class FlowTracer(object):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def flow_search(
|
def flow_search(
|
||||||
self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer
|
self, start_idx, start_dim, end_idx, end_dim
|
||||||
):
|
):
|
||||||
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
||||||
self.node_list[start_idx : end_idx + 1]
|
self.node_list[start_idx : end_idx + 1]
|
||||||
@ -920,7 +807,7 @@ class FlowTracer(object):
|
|||||||
if len(outputs) > 1:
|
if len(outputs) > 1:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node
|
cur_node_list = [self.node_list[end_idx]] # start from the last node
|
||||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||||
|
|
||||||
while len(cur_node_list) > 0:
|
while len(cur_node_list) > 0:
|
||||||
@ -930,12 +817,12 @@ class FlowTracer(object):
|
|||||||
# get cur node info
|
# get cur node info
|
||||||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||||
cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list)
|
cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list)
|
||||||
if cur_node_chunk_dim:
|
if cur_node_chunk_dim:
|
||||||
cur_node_compute = index_tracer._find_compute_trace_from_node(
|
cur_node_compute = self._find_compute_trace_from_node(
|
||||||
cur_node
|
cur_node
|
||||||
)
|
)
|
||||||
cur_node_source = index_tracer._find_source_trace_from_node(
|
cur_node_source = self._find_source_trace_from_node(
|
||||||
cur_node
|
cur_node
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -953,8 +840,6 @@ class FlowTracer(object):
|
|||||||
arg,
|
arg,
|
||||||
start_idx,
|
start_idx,
|
||||||
end_idx,
|
end_idx,
|
||||||
inputs,
|
|
||||||
index_tracer,
|
|
||||||
cur_node_chunk_dim,
|
cur_node_chunk_dim,
|
||||||
cur_node_compute,
|
cur_node_compute,
|
||||||
cur_node_source,
|
cur_node_source,
|
||||||
@ -970,7 +855,7 @@ class FlowTracer(object):
|
|||||||
for arg in arg_list:
|
for arg in arg_list:
|
||||||
if not (
|
if not (
|
||||||
start_idx
|
start_idx
|
||||||
<= _find_idx_by_name(arg.name, index_tracer.nodes_list)
|
<= _find_idx_by_name(arg.name, self.node_list)
|
||||||
< end_idx
|
< end_idx
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
@ -1029,7 +914,7 @@ class FlowTracer(object):
|
|||||||
if node_info["chunk_dim"] is None:
|
if node_info["chunk_dim"] is None:
|
||||||
maybe_prepose_nodes.append(node)
|
maybe_prepose_nodes.append(node)
|
||||||
maybe_prepose_nodes.sort(
|
maybe_prepose_nodes.sort(
|
||||||
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list),
|
key=lambda x: _find_idx_by_name(x.name, self.node_list),
|
||||||
reverse=True,
|
reverse=True,
|
||||||
) # from last node to first node
|
) # from last node to first node
|
||||||
prepose_nodes = []
|
prepose_nodes = []
|
||||||
@ -1081,7 +966,7 @@ class FlowTracer(object):
|
|||||||
maybe_prepose_nodes.remove(n)
|
maybe_prepose_nodes.remove(n)
|
||||||
# sort by index
|
# sort by index
|
||||||
prepose_nodes.sort(
|
prepose_nodes.sort(
|
||||||
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)
|
key=lambda x: _find_idx_by_name(x.name, self.node_list)
|
||||||
)
|
)
|
||||||
chunk_info["args"]["prepose_nodes"] = prepose_nodes
|
chunk_info["args"]["prepose_nodes"] = prepose_nodes
|
||||||
|
|
||||||
@ -1226,9 +1111,9 @@ class MemoryEstimator(object):
|
|||||||
for k, v in input_node_dim.items():
|
for k, v in input_node_dim.items():
|
||||||
# TODO: inherit dim should be list too, int now
|
# TODO: inherit dim should be list too, int now
|
||||||
inherit_dim = self.index_tracer._find_inherit_dim(
|
inherit_dim = self.index_tracer._find_inherit_dim(
|
||||||
input_node, v, self.index_tracer.nodes_list[k]
|
input_node, v, self.index_tracer.node_list[k]
|
||||||
)
|
)
|
||||||
if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
|
if k == _find_idx_by_name(node.name, self.index_tracer.node_list):
|
||||||
chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
|
chunk_ratio = float(chunk_size) / node_shape[inherit_dim]
|
||||||
return chunk_ratio
|
return chunk_ratio
|
||||||
for dim, source in enumerate(node_source):
|
for dim, source in enumerate(node_source):
|
||||||
@ -1412,8 +1297,6 @@ class ChunkRegionSearch(object):
|
|||||||
self.node_list = list(gm.graph.nodes)
|
self.node_list = list(gm.graph.nodes)
|
||||||
self.index_tracer = IndexTracer(gm)
|
self.index_tracer = IndexTracer(gm)
|
||||||
self.index_tracer.trace_index()
|
self.index_tracer.trace_index()
|
||||||
self.flow_tracer = FlowTracer(gm)
|
|
||||||
self.flow_tracer.trace_flow()
|
|
||||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
||||||
|
|
||||||
def _find_peak_node(self, mem_peak):
|
def _find_peak_node(self, mem_peak):
|
||||||
@ -1517,8 +1400,8 @@ class ChunkRegionSearch(object):
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
# flow search
|
# flow search
|
||||||
chunk_info = self.flow_tracer.flow_search(
|
chunk_info = self.index_tracer.flow_search(
|
||||||
start_idx, start_dim, end_idx, end_dim, self.index_tracer
|
start_idx, start_dim, end_idx, end_dim
|
||||||
)
|
)
|
||||||
if chunk_info is None:
|
if chunk_info is None:
|
||||||
continue
|
continue
|
||||||
|
Loading…
Reference in New Issue
Block a user