mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 14:10:29 +00:00
update source add
This commit is contained in:
parent
f5515e9978
commit
e5a5fbb8a9
@ -133,24 +133,28 @@ class IndexTracer(object):
|
||||
|
||||
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
|
||||
node_from_dim = self._transform_index(node_from, node_from_dim)
|
||||
node_from_trace = self._find_trace_from_node(node_from)
|
||||
node_from_trace_source = self._find_source_trace_from_node(node_from)
|
||||
node_to_dim = self._transform_index(node_to, node_to_dim)
|
||||
node_to_trace = self._find_trace_from_node(node_to)
|
||||
node_to_trace_source = self._find_source_trace_from_node(node_to)
|
||||
node_from_idx = _find_idx_by_name(node_from.name, self.node_list)
|
||||
if init:
|
||||
node_to_trace["source"][node_to_dim] = {}
|
||||
node_to_trace_source[node_to_dim] = {}
|
||||
# add dim to cur new source
|
||||
if node_from_idx not in node_to_trace["source"][node_to_dim]:
|
||||
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim]
|
||||
if node_from_idx not in node_to_trace_source[node_to_dim]:
|
||||
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
|
||||
else:
|
||||
if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]:
|
||||
node_to_trace["source"][node_to_dim][node_from_idx].append(
|
||||
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
|
||||
node_to_trace_source[node_to_dim][node_from_idx].append(
|
||||
node_from_dim
|
||||
)
|
||||
# update inputs source
|
||||
node_to_trace["source"][node_to_dim].update(
|
||||
node_from_trace["source"][node_from_dim]
|
||||
)
|
||||
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
|
||||
if node_idx not in node_to_trace_source[node_to_dim]:
|
||||
node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
|
||||
else:
|
||||
for d in node_dim:
|
||||
if d not in node_to_trace_source[node_to_dim][node_idx]:
|
||||
node_to_trace_source[node_to_dim][node_idx].append(d)
|
||||
|
||||
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
|
||||
if exclude == None:
|
||||
@ -1761,9 +1765,9 @@ class ChunkRegionSearch(object):
|
||||
)
|
||||
if self._stop_search(init_mem_peak, mem_peak):
|
||||
break
|
||||
# self.memory_estimator.estimate_chunk_inference_mem(
|
||||
# self.index_tracer.node_list, chunk_infos, print_mem=True
|
||||
# )
|
||||
self.memory_estimator.estimate_chunk_inference_mem(
|
||||
self.index_tracer.node_list, chunk_infos, print_mem=True
|
||||
)
|
||||
return chunk_infos
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user