[autochunk] support parsing blocks (#2506)

This commit is contained in:
oahzxl
2023-01-20 11:18:17 +08:00
committed by GitHub
parent 35c0c0006e
commit c04f183237
7 changed files with 314 additions and 22 deletions

View File

@@ -33,6 +33,8 @@ class TraceIndice(object):
self.indice_trace_list = self._init_indice_trace_list()
self.indice_view_list = {}
self.indice_count = -1
self.trace_range = []
self.active_node_list = []
def _init_indice_trace_list(self):
indice_trace_list = []
@@ -48,6 +50,10 @@ class TraceIndice(object):
indice_trace_list.append(cur_trace)
return indice_trace_list
def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
self.trace_range = trace_range
self.active_node_list = active_node_list
def _add_indice(self):
"""
Update the count and return it. To record the idx number.
@@ -493,6 +499,9 @@ class TraceIndice(object):
new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
for _ in range(new_dim_num):
self._del_dim(node_idx, 0)
delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
for _ in range(delete_dim_num):
self._add_dim(node_idx, 0)
self._assign_indice_as_input(node, node_idx)
for _, node_arg in enumerate(node_args):
@@ -513,6 +522,9 @@ class TraceIndice(object):
elif "None" == node_arg_str:
self._add_dim(node_idx, new_idx_count)
new_idx_count += 1
elif "0" == node_arg_str:
self._del_dim(node_idx, new_idx_count)
origin_idx_count += 1
else:
raise NotImplementedError()
@@ -596,6 +608,37 @@ class TraceIndice(object):
}
self.indice_view_list[node] = view_dict
def _clear_trace(self, node_idx: int) -> None:
"""
clear too far trace to speed up computation
"""
trace_range = None
for i in range(len(self.trace_range)):
if self.trace_range[i][1] == node_idx:
trace_range = (self.trace_range[i][0], self.trace_range[i][1])
break
if self.trace_range[i][1] > node_idx:
break
if trace_range is None:
return
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
active_nodes = set(flat_list(active_nodes))
active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
for i in range(trace_range[0], trace_range[1] + 1):
trace = self.indice_trace_list[i]
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
for dim_source in trace["source"]:
for k in list(dim_source.keys()):
if k < trace_range[0] and k not in active_nodes:
dim_source.pop(k)
def trace_indice(self):
for idx, node in enumerate(self.node_list):
if node.op == "placeholder":
@@ -655,3 +698,6 @@ class TraceIndice(object):
continue
else:
raise NotImplementedError(node.op, "op not implemented yet!")
# limit trace range
self._clear_trace(idx)