mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 06:29:09 +00:00
update source add
This commit is contained in:
parent
f5515e9978
commit
e5a5fbb8a9
@ -133,24 +133,28 @@ class IndexTracer(object):
|
|||||||
|
|
||||||
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
|
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
|
||||||
node_from_dim = self._transform_index(node_from, node_from_dim)
|
node_from_dim = self._transform_index(node_from, node_from_dim)
|
||||||
node_from_trace = self._find_trace_from_node(node_from)
|
node_from_trace_source = self._find_source_trace_from_node(node_from)
|
||||||
node_to_dim = self._transform_index(node_to, node_to_dim)
|
node_to_dim = self._transform_index(node_to, node_to_dim)
|
||||||
node_to_trace = self._find_trace_from_node(node_to)
|
node_to_trace_source = self._find_source_trace_from_node(node_to)
|
||||||
node_from_idx = _find_idx_by_name(node_from.name, self.node_list)
|
node_from_idx = _find_idx_by_name(node_from.name, self.node_list)
|
||||||
if init:
|
if init:
|
||||||
node_to_trace["source"][node_to_dim] = {}
|
node_to_trace_source[node_to_dim] = {}
|
||||||
# add dim to cur new source
|
# add dim to cur new source
|
||||||
if node_from_idx not in node_to_trace["source"][node_to_dim]:
|
if node_from_idx not in node_to_trace_source[node_to_dim]:
|
||||||
node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim]
|
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
|
||||||
else:
|
else:
|
||||||
if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]:
|
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
|
||||||
node_to_trace["source"][node_to_dim][node_from_idx].append(
|
node_to_trace_source[node_to_dim][node_from_idx].append(
|
||||||
node_from_dim
|
node_from_dim
|
||||||
)
|
)
|
||||||
# update inputs source
|
# update inputs source
|
||||||
node_to_trace["source"][node_to_dim].update(
|
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
|
||||||
node_from_trace["source"][node_from_dim]
|
if node_idx not in node_to_trace_source[node_to_dim]:
|
||||||
)
|
node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
|
||||||
|
else:
|
||||||
|
for d in node_dim:
|
||||||
|
if d not in node_to_trace_source[node_to_dim][node_idx]:
|
||||||
|
node_to_trace_source[node_to_dim][node_idx].append(d)
|
||||||
|
|
||||||
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
|
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
|
||||||
if exclude == None:
|
if exclude == None:
|
||||||
@ -1761,9 +1765,9 @@ class ChunkRegionSearch(object):
|
|||||||
)
|
)
|
||||||
if self._stop_search(init_mem_peak, mem_peak):
|
if self._stop_search(init_mem_peak, mem_peak):
|
||||||
break
|
break
|
||||||
# self.memory_estimator.estimate_chunk_inference_mem(
|
self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
# self.index_tracer.node_list, chunk_infos, print_mem=True
|
self.index_tracer.node_list, chunk_infos, print_mem=True
|
||||||
# )
|
)
|
||||||
return chunk_infos
|
return chunk_infos
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user