From 1e0fd11bc1773ca47cbd95fb19b86517265390ce Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 13 Dec 2022 10:01:30 +0800 Subject: [PATCH] support check_index_duplicate --- chunk_codegen.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 22d48f5d6..64bff4a80 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -179,7 +179,12 @@ class FlowTracer(object): "outputs_dim": end_dim, "args": {}, } - flow_flag = False + flow_block = False + + # TODO don't allow multi outputs now + if len(outputs) > 1: + flow_block = True + return flow_block, chunk_info for idx in range(start_idx, end_idx + 1): node = self.node_list[idx] @@ -199,10 +204,10 @@ class FlowTracer(object): self.node_list[end_idx], end_dim, node ) if mix_flow_node_dim is None: - flow_flag = True + flow_block = True break if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: - flow_flag = False + flow_block = False for i in self._get_same_flow_node( chunk_info["inputs"], mix_flow_node ): @@ -210,11 +215,15 @@ class FlowTracer(object): # else, we need to chunk mix var as well else: # TODO chunk another value - flow_flag = True + flow_block = True break else: raise NotImplementedError("%s not implemented" % node.name) + if flow_block: + flow_block = True + return flow_block, chunk_info + inputs_dim = [] remove_inputs = [] for input_node in chunk_info["inputs"]: @@ -250,7 +259,7 @@ class FlowTracer(object): if i not in chunk_info["inputs"]: chunk_info["inputs_non_chunk"].append(i) - return flow_flag, chunk_info + return flow_block, chunk_info class IndexTracer(object): @@ -869,14 +878,6 @@ class IndexTracer(object): if any(start_idx <= i <= end_idx for i in end_node_compute): return False return True - # end_node_trace_source = end_node_trace['source'][end_dim] - # for node_idx, node_dim in end_node_trace_source.items(): - # if node_idx < start_node_idx or node_idx > end_node_idx: - # continue - # compute_list = self.idx_trace_list[node_idx]['compute'][node_dim] - # if any(start_node_idx <= i <= end_node_idx for i in compute_list): - # return False - # return True def _get_node_chunk_dim(self, node_from, node_from_dim, node_to): node_from_source = self._find_source_trace_from_node(node_from) @@ -1240,10 +1241,10 @@ class ChunkRegionSearch(object): ): continue # detect flow meet - flow_flag, chunk_info = self.flow_tracer._detect_flow( + flow_block, chunk_info = self.flow_tracer._detect_flow( start_idx, start_dim, end_idx, end_dim, self.index_tracer ) - if flow_flag: + if flow_block: continue chunk_infos.append(chunk_info) chunk_infos = self._check_duplicate_map(chunk_infos)