code style

This commit is contained in:
oahzxl
2023-01-06 17:31:59 +08:00
parent a6cdbf9161
commit c3a2bf48b4
5 changed files with 46 additions and 36 deletions

View File

@@ -3,28 +3,31 @@ from .utils import find_idx_by_name
class ReorderGraph(object):
def __init__(self, index_tracer: TraceIndex) -> None:
self.index_tracer = index_tracer
self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))}
def __init__(self, trace_index: TraceIndex) -> None:
self.trace_index = trace_index
self.all_reorder_map = {
i: i for i in range(len(self.trace_index.idx_trace_list))
}
def _get_reorder_map(self, chunk_info):
reorder_map = {i: i for i in range(len(self.index_tracer.node_list))}
reorder_map = {i: i for i in range(len(self.trace_index.node_list))}
chunk_region_start = chunk_info["region"][0]
chunk_region_end = chunk_info["region"][1]
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
chunk_prepose_nodes_idx = [
find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes
find_idx_by_name(i.name, self.trace_index.node_list)
for i in chunk_prepose_nodes
]
# put prepose nodes ahead
for idx, n in enumerate(chunk_prepose_nodes):
n_idx = chunk_prepose_nodes_idx[idx]
reorder_map[n_idx] = chunk_region_start + idx
# put other nodes after prepose nodes
for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]:
for n in self.trace_index.node_list[chunk_region_start : chunk_region_end + 1]:
if n in chunk_prepose_nodes:
continue
n_idx = find_idx_by_name(n.name, self.index_tracer.node_list)
n_idx = find_idx_by_name(n.name, self.trace_index.node_list)
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
reorder_map[n_idx] = n_idx + pos
@@ -50,25 +53,25 @@ class ReorderGraph(object):
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
def _reorder_self_node_list(self, reorder_map):
new_node_list = [None for _ in range(len(self.index_tracer.node_list))]
new_node_list = [None for _ in range(len(self.trace_index.node_list))]
for old_idx, new_idx in reorder_map.items():
new_node_list[new_idx] = self.index_tracer.node_list[old_idx]
self.index_tracer.node_list = new_node_list
new_node_list[new_idx] = self.trace_index.node_list[old_idx]
self.trace_index.node_list = new_node_list
def _reorder_idx_trace(self, reorder_map):
# reorder list
new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))]
new_idx_trace_list = [None for _ in range(len(self.trace_index.idx_trace_list))]
for old_idx, new_idx in reorder_map.items():
new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx]
self.index_tracer.idx_trace_list = new_idx_trace_list
new_idx_trace_list[new_idx] = self.trace_index.idx_trace_list[old_idx]
self.trace_index.idx_trace_list = new_idx_trace_list
# update compute
for idx_trace in self.index_tracer.idx_trace_list:
for idx_trace in self.trace_index.idx_trace_list:
compute = idx_trace["compute"]
for dim_compute in compute:
for idx, i in enumerate(dim_compute):
dim_compute[idx] = reorder_map[i]
# update source
for idx_trace in self.index_tracer.idx_trace_list:
for idx_trace in self.trace_index.idx_trace_list:
source = idx_trace["source"]
for dim_idx, dim_source in enumerate(source):
new_dim_source = {}