update source add

This commit is contained in:
oahzxl 2022-12-31 01:00:06 +08:00
parent f5515e9978
commit e5a5fbb8a9

View File

@ -133,24 +133,28 @@ class IndexTracer(object):
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
node_from_dim = self._transform_index(node_from, node_from_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_from_trace_source = self._find_source_trace_from_node(node_from)
node_to_dim = self._transform_index(node_to, node_to_dim)
node_to_trace = self._find_trace_from_node(node_to)
node_to_trace_source = self._find_source_trace_from_node(node_to)
node_from_idx = _find_idx_by_name(node_from.name, self.node_list)
if init:
node_to_trace["source"][node_to_dim] = {}
node_to_trace_source[node_to_dim] = {}
# add dim to cur new source
if node_from_idx not in node_to_trace["source"][node_to_dim]:
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim]
if node_from_idx not in node_to_trace_source[node_to_dim]:
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
else:
if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]:
node_to_trace["source"][node_to_dim][node_from_idx].append(
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
node_to_trace_source[node_to_dim][node_from_idx].append(
node_from_dim
)
# update inputs source
node_to_trace["source"][node_to_dim].update(
node_from_trace["source"][node_from_dim]
)
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
if node_idx not in node_to_trace_source[node_to_dim]:
node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
else:
for d in node_dim:
if d not in node_to_trace_source[node_to_dim][node_idx]:
node_to_trace_source[node_to_dim][node_idx].append(d)
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
if exclude == None:
@ -1761,9 +1765,9 @@ class ChunkRegionSearch(object):
)
if self._stop_search(init_mem_peak, mem_peak):
break
# self.memory_estimator.estimate_chunk_inference_mem(
# self.index_tracer.node_list, chunk_infos, print_mem=True
# )
self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list, chunk_infos, print_mem=True
)
return chunk_infos