From 522f01741864f3565f8e97837ecc7289774ee127 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 13:41:51 +0800 Subject: [PATCH] code style --- chunk_codegen.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 0b0a164fe..a8b970116 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1004,7 +1004,7 @@ class FlowTracer(object): # if already in node_info, arg dim must be same if arg_node in all_node_info: - if all_node_info[arg_node]['chunk_dim'] != arg_dim: + if all_node_info[arg_node]["chunk_dim"] != arg_dim: return False all_node_info[arg_node]["fix_dim"] = list( set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) @@ -1132,16 +1132,19 @@ class FlowTracer(object): # get all possible prepose nodes maybe_prepose_nodes = [] for node, node_info in all_node_info.items(): - if node_info['chunk_dim'] is None: + if node_info["chunk_dim"] is None: maybe_prepose_nodes.append(node) - maybe_prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), reverse=True) # from last node to first node + maybe_prepose_nodes.sort( + key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), + reverse=True, + ) # from last node to first node prepose_nodes = [] # set every node as root, search its args, if all legal, turn root and args as prepose nodes while len(maybe_prepose_nodes) > 0: tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] tmp_cur_related_prepose_nodes = [] prepose_flag = True - + # loop cur node's all arg until out of chunk while len(tmp_cur_prepose_nodes) > 0: tmp_next_prepose_nodes = [] @@ -1151,20 +1154,28 @@ class FlowTracer(object): if type(cur_prepose_node_arg) != type(cur_prepose_node): continue # out of loop - if not (start_idx <= _find_idx_by_name(cur_prepose_node_arg.name, self.node_list) < end_idx): + if not ( + start_idx + <= _find_idx_by_name( + cur_prepose_node_arg.name, self.node_list + ) + < end_idx + ): continue # compute op in loop elif cur_prepose_node_arg in all_node_info: - if all_node_info[cur_prepose_node_arg]['chunk_dim'] is None: + if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None: tmp_next_prepose_nodes.append(cur_prepose_node_arg) else: prepose_flag = False - break; break; break + break + break + break # non compute op else: tmp_next_prepose_nodes.append(cur_prepose_node_arg) tmp_cur_prepose_nodes = tmp_next_prepose_nodes - + if prepose_flag == False: maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) continue @@ -1175,21 +1186,21 @@ class FlowTracer(object): if n in maybe_prepose_nodes: maybe_prepose_nodes.remove(n) # sort by index - prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)) + prepose_nodes.sort( + key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list) + ) chunk_info["args"]["prepose_nodes"] = prepose_nodes - + # we need to log input nodes to avoid deleteing them in the loop chunk_node_list = self.node_list[start_idx : end_idx + 1] # also need to get some prepose node's arg out of non_chunk_inputs for n in prepose_nodes: chunk_node_list.remove(n) - non_chunk_inputs = _find_chunk_all_input_nodes( - chunk_node_list - ) + non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list) for i in non_chunk_inputs: if i not in chunk_info["inputs"] and i not in prepose_nodes: chunk_info["inputs_non_chunk"].append(i) - + return chunk_info