From 774d34f1aa2f9534557dd4a0ca866392a496e448 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 13:41:10 +0800 Subject: [PATCH] refactor flow search --- chunk_codegen.py | 78 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 20 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index eb16361c0..0b0a164fe 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1004,7 +1004,7 @@ class FlowTracer(object): # if already in node_info, arg dim must be same if arg_node in all_node_info: - if all_node_info[arg_node] != arg_dim: + if all_node_info[arg_node]['chunk_dim'] != arg_dim: return False all_node_info[arg_node]["fix_dim"] = list( set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim) @@ -1128,14 +1128,68 @@ class FlowTracer(object): "args": {}, } + # move useless nodes ahead of loop + # get all possible prepose nodes + maybe_prepose_nodes = [] + for node, node_info in all_node_info.items(): + if node_info['chunk_dim'] is None: + maybe_prepose_nodes.append(node) + maybe_prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), reverse=True) # from last node to first node + prepose_nodes = [] + # set every node as root, search its args, if all legal, turn root and args as prepose nodes + while len(maybe_prepose_nodes) > 0: + tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]] + tmp_cur_related_prepose_nodes = [] + prepose_flag = True + + # loop cur node's all arg until out of chunk + while len(tmp_cur_prepose_nodes) > 0: + tmp_next_prepose_nodes = [] + tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes) + for cur_prepose_node in tmp_cur_prepose_nodes: + for cur_prepose_node_arg in cur_prepose_node.args: + if type(cur_prepose_node_arg) != type(cur_prepose_node): + continue + # out of loop + if not (start_idx <= _find_idx_by_name(cur_prepose_node_arg.name, self.node_list) < end_idx): + continue + # compute op in loop + elif cur_prepose_node_arg in all_node_info: + if all_node_info[cur_prepose_node_arg]['chunk_dim'] is None: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + else: + prepose_flag = False + break; break; break + # non compute op + else: + tmp_next_prepose_nodes.append(cur_prepose_node_arg) + tmp_cur_prepose_nodes = tmp_next_prepose_nodes + + if prepose_flag == False: + maybe_prepose_nodes.remove(maybe_prepose_nodes[0]) + continue + else: + for n in tmp_cur_related_prepose_nodes: + if n not in prepose_nodes: + prepose_nodes.append(n) + if n in maybe_prepose_nodes: + maybe_prepose_nodes.remove(n) + # sort by index + prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)) + chunk_info["args"]["prepose_nodes"] = prepose_nodes + # we need to log input nodes to avoid deleteing them in the loop + chunk_node_list = self.node_list[start_idx : end_idx + 1] + # also need to get some prepose node's arg out of non_chunk_inputs + for n in prepose_nodes: + chunk_node_list.remove(n) non_chunk_inputs = _find_chunk_all_input_nodes( - self.node_list[start_idx : end_idx + 1] + chunk_node_list ) for i in non_chunk_inputs: - if i not in chunk_info["inputs"]: + if i not in chunk_info["inputs"] and i not in prepose_nodes: chunk_info["inputs_non_chunk"].append(i) - + return chunk_info @@ -1541,16 +1595,6 @@ class ChunkRegionSearch(object): continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): - if ( - start_idx == 199 - and end_idx == 229 - and start_dim == 2 - and end_dim == 2 - ): - print(1) - self.flow_tracer.flow_search( - start_idx, start_dim, end_idx, end_dim, self.index_tracer - ) # dim size cannot be 1 if ( _get_node_shape(end_node)[end_dim] == 1 @@ -1567,12 +1611,6 @@ class ChunkRegionSearch(object): start_idx, end_dim, end_node, end_idx ): continue - # detect flow meet - # flow_block, chunk_info = self.flow_tracer._detect_flow( - # start_idx, start_dim, end_idx, end_dim, self.index_tracer - # ) - # if flow_block: - # continue # flow search chunk_info = self.flow_tracer.flow_search( start_idx, start_dim, end_idx, end_dim, self.index_tracer