[autochunk] refactor chunk memory estimation (#2762)

* refact memory code

* dont log free var memory

* add memory align

* update chunk target

* update setting for new memory

* finish test

* update tracer

* update typo

* update test
This commit is contained in:
Xuanlei Zhao
2023-03-08 16:22:30 +08:00
committed by GitHub
parent b51bfec357
commit 2ca9728cbb
12 changed files with 294 additions and 422 deletions

View File

@@ -33,7 +33,6 @@ class TraceIndice(object):
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
self.trace_range = []
self.active_node_list = []
def _init_indice_trace_list(self) -> List:
@@ -50,8 +49,7 @@ class TraceIndice(object):
indice_trace_list.append(cur_trace)
return indice_trace_list
def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
self.trace_range = trace_range
def set_active_nodes(self, active_node_list: List) -> None:
self.active_node_list = active_node_list
def _add_indice(self) -> int:
@@ -731,23 +729,35 @@ class TraceIndice(object):
dim_from.reverse()
# search view list
for view_node, view_dict in self.indice_view_list.items():
if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
and view_dict["dim_from"] == dim_to):
# inheirt indice from current node
if len_diff == 1:
if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif len_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# inherid indice from input node of last view
for dim_to_i in dim_to:
self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# for view_node, view_dict in self.indice_view_list.items():
# if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
# and view_dict["dim_from"] == dim_to):
# # inheirt indice from current node
# if len_diff == 1:
# if origin_shape[dim_from[0]] == 1:
# self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
# elif origin_shape[dim_from[1]] == 1:
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# elif len_diff == -1:
# if target_shape[dim_to[0]] == 1:
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
# elif target_shape[dim_to[1]] == 1:
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# # inherid indice from input node of last view
# for dim_to_i in dim_to:
# self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
# inheirt indice from current node
if len_diff == 1:
if origin_shape[dim_from[0]] == 1:
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
elif origin_shape[dim_from[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
elif len_diff == -1:
if target_shape[dim_to[0]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
elif target_shape[dim_to[1]] == 1:
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
# log view, not used now
view_dict = {
@@ -762,32 +772,22 @@ class TraceIndice(object):
"""
clear too far trace to speed up computation
"""
trace_range = None
for i in range(len(self.trace_range)):
if self.trace_range[i][1] == node_idx:
trace_range = (self.trace_range[i][0], self.trace_range[i][1])
break
if self.trace_range[i][1] > node_idx:
break
if trace_range is None:
return
trace_barrier = max(node_idx - 100, 0)
active_nodes = self.active_node_list[trace_barrier]
active_nodes = [self.node_mgr.find_node_idx(i) for i in active_nodes.keys()]
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
active_nodes = set(flat_list(active_nodes))
active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
for i in range(trace_range[0], trace_range[1] + 1):
trace = self.indice_trace_list[i]
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if (dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes):
dim_compute.pop(i)
continue
# clear source
for dim_source in trace["source"]:
for k in list(dim_source.keys()):
if k < trace_range[0] and k not in active_nodes:
dim_source.pop(k)
trace = self.indice_trace_list[node_idx]
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
dim_compute.pop(i)
continue
# clear source
for dim_source in trace["source"]:
for k in list(dim_source.keys()):
if k < trace_barrier and k not in active_nodes:
dim_source.pop(k)
def trace_indice(self) -> None:
for idx, node in enumerate(self.node_mgr.get_node_list()):