From e5a5fbb8a94313722542b72f601b8433eef1e5dc Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 31 Dec 2022 01:00:06 +0800 Subject: [PATCH] update source add --- chunk_codegen.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 1c8be65d4..de58a61b9 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -133,24 +133,28 @@ class IndexTracer(object): def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): node_from_dim = self._transform_index(node_from, node_from_dim) - node_from_trace = self._find_trace_from_node(node_from) + node_from_trace_source = self._find_source_trace_from_node(node_from) node_to_dim = self._transform_index(node_to, node_to_dim) - node_to_trace = self._find_trace_from_node(node_to) + node_to_trace_source = self._find_source_trace_from_node(node_to) node_from_idx = _find_idx_by_name(node_from.name, self.node_list) if init: - node_to_trace["source"][node_to_dim] = {} + node_to_trace_source[node_to_dim] = {} # add dim to cur new source - if node_from_idx not in node_to_trace["source"][node_to_dim]: - node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim] + if node_from_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] else: - if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]: - node_to_trace["source"][node_to_dim][node_from_idx].append( + if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: + node_to_trace_source[node_to_dim][node_from_idx].append( node_from_dim ) # update inputs source - node_to_trace["source"][node_to_dim].update( - node_from_trace["source"][node_from_dim] - ) + for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): + if node_idx not in node_to_trace_source[node_to_dim]: + node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim) + else: + for d in node_dim: + if d not in node_to_trace_source[node_to_dim][node_idx]: + node_to_trace_source[node_to_dim][node_idx].append(d) def _mark_computation_from_node(self, node_from, node_to, exclude=None): if exclude == None: @@ -1761,9 +1765,9 @@ class ChunkRegionSearch(object): ) if self._stop_search(init_mem_peak, mem_peak): break - # self.memory_estimator.estimate_chunk_inference_mem( - # self.index_tracer.node_list, chunk_infos, print_mem=True - # ) + self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, chunk_infos, print_mem=True + ) return chunk_infos