code style

This commit is contained in:
oahzxl 2022-12-23 13:41:51 +08:00
parent 774d34f1aa
commit 522f017418

View File

@ -1004,7 +1004,7 @@ class FlowTracer(object):
# if already in node_info, arg dim must be same # if already in node_info, arg dim must be same
if arg_node in all_node_info: if arg_node in all_node_info:
if all_node_info[arg_node]['chunk_dim'] != arg_dim: if all_node_info[arg_node]["chunk_dim"] != arg_dim:
return False return False
all_node_info[arg_node]["fix_dim"] = list( all_node_info[arg_node]["fix_dim"] = list(
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
@ -1132,9 +1132,12 @@ class FlowTracer(object):
# get all possible prepose nodes # get all possible prepose nodes
maybe_prepose_nodes = [] maybe_prepose_nodes = []
for node, node_info in all_node_info.items(): for node, node_info in all_node_info.items():
if node_info['chunk_dim'] is None: if node_info["chunk_dim"] is None:
maybe_prepose_nodes.append(node) maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), reverse=True) # from last node to first node maybe_prepose_nodes.sort(
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list),
reverse=True,
) # from last node to first node
prepose_nodes = [] prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes # set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0: while len(maybe_prepose_nodes) > 0:
@ -1151,15 +1154,23 @@ class FlowTracer(object):
if type(cur_prepose_node_arg) != type(cur_prepose_node): if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue continue
# out of loop # out of loop
if not (start_idx <= _find_idx_by_name(cur_prepose_node_arg.name, self.node_list) < end_idx): if not (
start_idx
<= _find_idx_by_name(
cur_prepose_node_arg.name, self.node_list
)
< end_idx
):
continue continue
# compute op in loop # compute op in loop
elif cur_prepose_node_arg in all_node_info: elif cur_prepose_node_arg in all_node_info:
if all_node_info[cur_prepose_node_arg]['chunk_dim'] is None: if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
tmp_next_prepose_nodes.append(cur_prepose_node_arg) tmp_next_prepose_nodes.append(cur_prepose_node_arg)
else: else:
prepose_flag = False prepose_flag = False
break; break; break break
break
break
# non compute op # non compute op
else: else:
tmp_next_prepose_nodes.append(cur_prepose_node_arg) tmp_next_prepose_nodes.append(cur_prepose_node_arg)
@ -1175,7 +1186,9 @@ class FlowTracer(object):
if n in maybe_prepose_nodes: if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n) maybe_prepose_nodes.remove(n)
# sort by index # sort by index
prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)) prepose_nodes.sort(
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)
)
chunk_info["args"]["prepose_nodes"] = prepose_nodes chunk_info["args"]["prepose_nodes"] = prepose_nodes
# we need to log input nodes to avoid deleteing them in the loop # we need to log input nodes to avoid deleteing them in the loop
@ -1183,9 +1196,7 @@ class FlowTracer(object):
# also need to get some prepose node's arg out of non_chunk_inputs # also need to get some prepose node's arg out of non_chunk_inputs
for n in prepose_nodes: for n in prepose_nodes:
chunk_node_list.remove(n) chunk_node_list.remove(n)
non_chunk_inputs = _find_chunk_all_input_nodes( non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list)
chunk_node_list
)
for i in non_chunk_inputs: for i in non_chunk_inputs:
if i not in chunk_info["inputs"] and i not in prepose_nodes: if i not in chunk_info["inputs"] and i not in prepose_nodes:
chunk_info["inputs_non_chunk"].append(i) chunk_info["inputs_non_chunk"].append(i)