update source add

This commit is contained in:
oahzxl 2022-12-31 01:00:06 +08:00
parent f5515e9978
commit e5a5fbb8a9

View File

@ -133,24 +133,28 @@ class IndexTracer(object):
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False): def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
node_from_dim = self._transform_index(node_from, node_from_dim) node_from_dim = self._transform_index(node_from, node_from_dim)
node_from_trace = self._find_trace_from_node(node_from) node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_index(node_to, node_to_dim) node_to_dim = self._transform_index(node_to, node_to_dim)
node_to_trace = self._find_trace_from_node(node_to) node_to_trace_source = self._find_source_trace_from_node(node_to)
node_from_idx = _find_idx_by_name(node_from.name, self.node_list) node_from_idx = _find_idx_by_name(node_from.name, self.node_list)
if init: if init:
node_to_trace["source"][node_to_dim] = {} node_to_trace_source[node_to_dim] = {}
# add dim to cur new source # add dim to cur new source
if node_from_idx not in node_to_trace["source"][node_to_dim]: if node_from_idx not in node_to_trace_source[node_to_dim]:
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim] node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
else: else:
if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]: if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
node_to_trace["source"][node_to_dim][node_from_idx].append( node_to_trace_source[node_to_dim][node_from_idx].append(
node_from_dim node_from_dim
) )
# update inputs source # update inputs source
node_to_trace["source"][node_to_dim].update( for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
node_from_trace["source"][node_from_dim] if node_idx not in node_to_trace_source[node_to_dim]:
) node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
else:
for d in node_dim:
if d not in node_to_trace_source[node_to_dim][node_idx]:
node_to_trace_source[node_to_dim][node_idx].append(d)
def _mark_computation_from_node(self, node_from, node_to, exclude=None): def _mark_computation_from_node(self, node_from, node_to, exclude=None):
if exclude == None: if exclude == None:
@ -1761,9 +1765,9 @@ class ChunkRegionSearch(object):
) )
if self._stop_search(init_mem_peak, mem_peak): if self._stop_search(init_mem_peak, mem_peak):
break break
# self.memory_estimator.estimate_chunk_inference_mem( self.memory_estimator.estimate_chunk_inference_mem(
# self.index_tracer.node_list, chunk_infos, print_mem=True self.index_tracer.node_list, chunk_infos, print_mem=True
# ) )
return chunk_infos return chunk_infos