mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[autochunk] support parsing blocks (#2506)
This commit is contained in:
@@ -33,6 +33,8 @@ class TraceIndice(object):
|
||||
self.indice_trace_list = self._init_indice_trace_list()
|
||||
self.indice_view_list = {}
|
||||
self.indice_count = -1
|
||||
self.trace_range = []
|
||||
self.active_node_list = []
|
||||
|
||||
def _init_indice_trace_list(self):
|
||||
indice_trace_list = []
|
||||
@@ -48,6 +50,10 @@ class TraceIndice(object):
|
||||
indice_trace_list.append(cur_trace)
|
||||
return indice_trace_list
|
||||
|
||||
def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
    """
    Store the trace ranges and the active node list on this tracer.

    Both values are kept by reference (no copy is made); they are read
    later by ``_clear_trace``.

    Args:
        trace_range (List): ranges to trace; ``_clear_trace`` reads
            element ``[0]`` and ``[1]`` of every entry.
        active_node_list (List): list that ``_clear_trace`` slices with
            the bounds of a trace range.
    """
    self.trace_range, self.active_node_list = trace_range, active_node_list
|
||||
|
||||
def _add_indice(self):
|
||||
"""
|
||||
Update the count and return it. To record the idx number.
|
||||
@@ -493,6 +499,9 @@ class TraceIndice(object):
|
||||
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
|
||||
for _ in range(new_dim_num):
|
||||
self._del_dim(node_idx, 0)
|
||||
delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
|
||||
for _ in range(delete_dim_num):
|
||||
self._add_dim(node_idx, 0)
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
|
||||
for _, node_arg in enumerate(node_args):
|
||||
@@ -513,6 +522,9 @@ class TraceIndice(object):
|
||||
elif "None" == node_arg_str:
|
||||
self._add_dim(node_idx, new_idx_count)
|
||||
new_idx_count += 1
|
||||
elif "0" == node_arg_str:
|
||||
self._del_dim(node_idx, new_idx_count)
|
||||
origin_idx_count += 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -596,6 +608,37 @@ class TraceIndice(object):
|
||||
}
|
||||
self.indice_view_list[node] = view_dict
|
||||
|
||||
def _clear_trace(self, node_idx: int) -> None:
    """
    Drop trace entries that are too far back, to speed up later computation.

    Nothing happens unless ``node_idx`` is the end index of one of the
    ranges stored in ``self.trace_range``. When it is, every trace in
    ``self.indice_trace_list`` inside that range has its "compute" lists
    and "source" dicts pruned in place: an index is removed when it lies
    before the start of the range and is not among the nodes listed in
    ``self.active_node_list`` for that range.

    Args:
        node_idx (int): index of the node that has just been processed.
    """
    # Look for the range that ends exactly at node_idx. Stopping at the
    # first larger end index presumes self.trace_range is ordered by end
    # index -- TODO confirm against the caller of set_trace_range.
    trace_range = None
    for region in self.trace_range:
        if region[1] == node_idx:
            trace_range = (region[0], region[1])
            break
        if region[1] > node_idx:
            break
    if trace_range is None:
        return
    range_start, range_end = trace_range

    # Resolve the node names listed for this range to indices once. A set
    # gives O(1) membership tests inside the nested loops below (assumes
    # find_idx_by_name returns a hashable index -- TODO confirm).
    active_names = set(flat_list(self.active_node_list[range_start:range_end + 1]))
    active_nodes = {find_idx_by_name(name, self.node_list) for name in active_names}

    def _is_stale(idx) -> bool:
        # Stale: before the range start and not listed for this range.
        return idx < range_start and idx not in active_nodes

    for trace_idx in range(range_start, range_end + 1):
        trace = self.indice_trace_list[trace_idx]
        # clear compute: slice assignment filters the same list object in place
        for dim_compute in trace["compute"]:
            dim_compute[:] = [idx for idx in dim_compute if not _is_stale(idx)]
        # clear source: collect the keys first so the dict is not mutated
        # while it is being iterated
        for dim_source in trace["source"]:
            for key in [k for k in dim_source if _is_stale(k)]:
                dim_source.pop(key)
|
||||
|
||||
def trace_indice(self):
|
||||
for idx, node in enumerate(self.node_list):
|
||||
if node.op == "placeholder":
|
||||
@@ -655,3 +698,6 @@ class TraceIndice(object):
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(node.op, "op not implemented yet!")
|
||||
|
||||
# limit trace range
|
||||
self._clear_trace(idx)
|
||||
|
Reference in New Issue
Block a user