mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-20 00:47:13 +00:00
code style
This commit is contained in:
parent
4f5e105af3
commit
fa5e6fbf96
@ -65,9 +65,8 @@ def _is_non_compute_node_except_placeholder_output(node):
|
|||||||
|
|
||||||
|
|
||||||
class IndexTracer(object):
|
class IndexTracer(object):
|
||||||
def __init__(self, gm) -> None:
|
def __init__(self, node_list) -> None:
|
||||||
self.gm = gm
|
self.node_list = node_list
|
||||||
self.node_list = list(gm.graph.nodes)
|
|
||||||
self.idx_trace_list = self._init_idx_trace_list()
|
self.idx_trace_list = self._init_idx_trace_list()
|
||||||
self.idx_trace_equal = []
|
self.idx_trace_equal = []
|
||||||
self.idx_view_list = []
|
self.idx_view_list = []
|
||||||
@ -797,9 +796,7 @@ class IndexTracer(object):
|
|||||||
next_node_list.append(arg_node)
|
next_node_list.append(arg_node)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def flow_search(
|
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||||
self, start_idx, start_dim, end_idx, end_dim
|
|
||||||
):
|
|
||||||
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
|
||||||
self.node_list[start_idx : end_idx + 1]
|
self.node_list[start_idx : end_idx + 1]
|
||||||
)
|
)
|
||||||
@ -819,12 +816,8 @@ class IndexTracer(object):
|
|||||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||||
cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list)
|
cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list)
|
||||||
if cur_node_chunk_dim:
|
if cur_node_chunk_dim:
|
||||||
cur_node_compute = self._find_compute_trace_from_node(
|
cur_node_compute = self._find_compute_trace_from_node(cur_node)
|
||||||
cur_node
|
cur_node_source = self._find_source_trace_from_node(cur_node)
|
||||||
)
|
|
||||||
cur_node_source = self._find_source_trace_from_node(
|
|
||||||
cur_node
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
cur_node_compute = cur_node_source = None
|
cur_node_compute = cur_node_source = None
|
||||||
|
|
||||||
@ -965,9 +958,7 @@ class IndexTracer(object):
|
|||||||
if n in maybe_prepose_nodes:
|
if n in maybe_prepose_nodes:
|
||||||
maybe_prepose_nodes.remove(n)
|
maybe_prepose_nodes.remove(n)
|
||||||
# sort by index
|
# sort by index
|
||||||
prepose_nodes.sort(
|
prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, self.node_list))
|
||||||
key=lambda x: _find_idx_by_name(x.name, self.node_list)
|
|
||||||
)
|
|
||||||
chunk_info["args"]["prepose_nodes"] = prepose_nodes
|
chunk_info["args"]["prepose_nodes"] = prepose_nodes
|
||||||
|
|
||||||
# we need to log input nodes to avoid deleteing them in the loop
|
# we need to log input nodes to avoid deleteing them in the loop
|
||||||
@ -1295,7 +1286,9 @@ class ChunkRegionSearch(object):
|
|||||||
def __init__(self, gm) -> None:
|
def __init__(self, gm) -> None:
|
||||||
self.gm = gm
|
self.gm = gm
|
||||||
self.node_list = list(gm.graph.nodes)
|
self.node_list = list(gm.graph.nodes)
|
||||||
self.index_tracer = IndexTracer(gm)
|
self.index_tracer = IndexTracer(
|
||||||
|
self.node_list
|
||||||
|
) # node list shared in index tracer
|
||||||
self.index_tracer.trace_index()
|
self.index_tracer.trace_index()
|
||||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user