mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 13:11:05 +00:00
[autochunk] refactor chunk memory estimation (#2762)
* refactor memory code * don't log free var memory * add memory align * update chunk target * update setting for new memory * finish test * update tracer * update typo * update test
This commit is contained in:
@@ -33,7 +33,6 @@ class TraceIndice(object):
|
||||
self.indice_trace_list = self._init_indice_trace_list()
|
||||
self.indice_view_list = {}
|
||||
self.indice_count = -1
|
||||
self.trace_range = []
|
||||
self.active_node_list = []
|
||||
|
||||
def _init_indice_trace_list(self) -> List:
|
||||
@@ -50,8 +49,7 @@ class TraceIndice(object):
|
||||
indice_trace_list.append(cur_trace)
|
||||
return indice_trace_list
|
||||
|
||||
def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
|
||||
self.trace_range = trace_range
|
||||
def set_active_nodes(self, active_node_list: List) -> None:
|
||||
self.active_node_list = active_node_list
|
||||
|
||||
def _add_indice(self) -> int:
|
||||
@@ -731,23 +729,35 @@ class TraceIndice(object):
|
||||
dim_from.reverse()
|
||||
|
||||
# search view list
|
||||
for view_node, view_dict in self.indice_view_list.items():
|
||||
if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
|
||||
and view_dict["dim_from"] == dim_to):
|
||||
# inherit indice from current node
|
||||
if len_diff == 1:
|
||||
if origin_shape[dim_from[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
|
||||
elif origin_shape[dim_from[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
elif len_diff == -1:
|
||||
if target_shape[dim_to[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
|
||||
elif target_shape[dim_to[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
# inherit indice from input node of last view
|
||||
for dim_to_i in dim_to:
|
||||
self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
|
||||
# for view_node, view_dict in self.indice_view_list.items():
|
||||
# if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
|
||||
# and view_dict["dim_from"] == dim_to):
|
||||
# # inherit indice from current node
|
||||
# if len_diff == 1:
|
||||
# if origin_shape[dim_from[0]] == 1:
|
||||
# self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
|
||||
# elif origin_shape[dim_from[1]] == 1:
|
||||
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
# elif len_diff == -1:
|
||||
# if target_shape[dim_to[0]] == 1:
|
||||
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
|
||||
# elif target_shape[dim_to[1]] == 1:
|
||||
# self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
# # inherit indice from input node of last view
|
||||
# for dim_to_i in dim_to:
|
||||
# self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
|
||||
|
||||
# inherit indice from current node
|
||||
if len_diff == 1:
|
||||
if origin_shape[dim_from[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
|
||||
elif origin_shape[dim_from[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
elif len_diff == -1:
|
||||
if target_shape[dim_to[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
|
||||
elif target_shape[dim_to[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
|
||||
# log view, not used now
|
||||
view_dict = {
|
||||
@@ -762,32 +772,22 @@ class TraceIndice(object):
|
||||
"""
|
||||
clear too far trace to speed up computation
|
||||
"""
|
||||
trace_range = None
|
||||
for i in range(len(self.trace_range)):
|
||||
if self.trace_range[i][1] == node_idx:
|
||||
trace_range = (self.trace_range[i][0], self.trace_range[i][1])
|
||||
break
|
||||
if self.trace_range[i][1] > node_idx:
|
||||
break
|
||||
if trace_range is None:
|
||||
return
|
||||
trace_barrier = max(node_idx - 100, 0)
|
||||
active_nodes = self.active_node_list[trace_barrier]
|
||||
active_nodes = [self.node_mgr.find_node_idx(i) for i in active_nodes.keys()]
|
||||
|
||||
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
|
||||
active_nodes = set(flat_list(active_nodes))
|
||||
active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
|
||||
for i in range(trace_range[0], trace_range[1] + 1):
|
||||
trace = self.indice_trace_list[i]
|
||||
# clear compute
|
||||
for dim_compute in trace["compute"]:
|
||||
for i in range(len(dim_compute) - 1, -1, -1):
|
||||
if (dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes):
|
||||
dim_compute.pop(i)
|
||||
continue
|
||||
# clear source
|
||||
for dim_source in trace["source"]:
|
||||
for k in list(dim_source.keys()):
|
||||
if k < trace_range[0] and k not in active_nodes:
|
||||
dim_source.pop(k)
|
||||
trace = self.indice_trace_list[node_idx]
|
||||
# clear compute
|
||||
for dim_compute in trace["compute"]:
|
||||
for i in range(len(dim_compute) - 1, -1, -1):
|
||||
if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
|
||||
dim_compute.pop(i)
|
||||
continue
|
||||
# clear source
|
||||
for dim_source in trace["source"]:
|
||||
for k in list(dim_source.keys()):
|
||||
if k < trace_barrier and k not in active_nodes:
|
||||
dim_source.pop(k)
|
||||
|
||||
def trace_indice(self) -> None:
|
||||
for idx, node in enumerate(self.node_mgr.get_node_list()):
|
||||
|
Reference in New Issue
Block a user