redesign index tracer, add source and change compute

This commit is contained in:
oahzxl 2022-12-09 17:39:02 +08:00
parent 2b4ebcc278
commit 979e61db92

View File

@ -16,6 +16,11 @@ def _delete_free_var_from_last_use(user_to_last_uses):
if n.op == 'placeholder': if n.op == 'placeholder':
user_to_last_uses[key].remove(n) user_to_last_uses[key].remove(n)
def _get_node_shape(node):
if hasattr(node.meta['tensor_meta'], "shape"):
return node.meta['tensor_meta'].shape
return None
class FlowTracer(object): class FlowTracer(object):
def __init__(self, gm) -> None: def __init__(self, gm) -> None:
@ -136,11 +141,25 @@ class IndexTracer(object):
def __init__(self, gm) -> None: def __init__(self, gm) -> None:
self.gm = gm self.gm = gm
self.nodes_list = list(gm.graph.nodes) self.nodes_list = list(gm.graph.nodes)
self.idx_trace_list = [{'idx': [], 'compute': {}} for _ in range(len(self.nodes_list))] self.idx_trace_list = self._init_idx_trace_list()
self.idx_trace_equal = [] self.idx_trace_equal = []
self.idx_view_list = [] self.idx_view_list = []
self.idx_count = -1 self.idx_count = -1
def _init_idx_trace_list(self):
idx_trace_list = []
for n in self.nodes_list:
if _get_node_shape(n) != None:
cur_trace = {
'idx': [None for _ in range(len(_get_node_shape(n)))],
'compute': [[] for _ in range(len(_get_node_shape(n)))],
'source': [[] for _ in range(len(_get_node_shape(n)))],
}
else:
cur_trace = {'idx': [], 'compute': [], 'source': []}
idx_trace_list.append(cur_trace)
return idx_trace_list
def _add_index(self): def _add_index(self):
""" """
Update the count and return it. To record the idx number. Update the count and return it. To record the idx number.
@ -151,26 +170,67 @@ class IndexTracer(object):
self.idx_count += 1 self.idx_count += 1
return self.idx_count return self.idx_count
def _inherit_computation(self, node_from, node_to): def _del_dim(self, idx, dim_idx):
""" self.idx_trace_list[idx]['idx'].pop(dim_idx)
Inherit computed dim from node_from to node_to. self.idx_trace_list[idx]['compute'].pop(dim_idx)
If a dim in node_from is marked as computed and exists in node_to, self.idx_trace_list[idx]['source'].pop(dim_idx)
still mark it as computed in node_to.
Args: def _add_dim(self, idx, dim_idx):
node_from (node): node to be inherited self.idx_trace_list[idx]['idx'].insert(dim_idx, self._add_index())
node_to (node): new node to inherit self.idx_trace_list[idx]['compute'].insert(dim_idx, [])
""" self.idx_trace_list[idx]['source'].insert(dim_idx, [])
_, compute_from = self._find_trace_from_node(node_from)
idx_to, compute_to = self._find_trace_from_node(node_to) def _transform_index(self, node, node_dim):
for k, v in compute_from.items(): node_idx = self._find_idx_trace_from_node(node)
if k in idx_to: dims = list(range(len(node_idx)))
if k in compute_to: return dims[node_dim]
compute_to[k].extend(v)
def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim):
node_from_dim = self._transform_index(node_from, node_from_dim)
node_to_dim = self._transform_index(node_to, node_to_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_trace = self._find_trace_from_node(node_to)
node_to_trace['idx'][node_to_dim] = node_from_trace['idx'][node_from_dim]
node_to_trace['compute'][node_to_dim] = copy.deepcopy(node_from_trace['compute'][node_from_dim])
node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
node_to_trace['source'][node_to_dim] = []
node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
def _inherit_all_computation(self, node_from, node_to):
node_from_compute = self._find_compute_trace_from_node(node_from)
node_to_compute = self._find_compute_trace_from_node(node_to)
assert len(node_from_compute) == len(node_to_compute)
for i in range(len(node_from_compute)):
self._add_source(node_from, i, node_to, i)
node_to_compute[i] = copy.deepcopy(node_from_compute[i])
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim):
node_from_dim = self._transform_index(node_from, node_from_dim)
node_from_trace = self._find_trace_from_node(node_from)
node_to_dim = self._transform_index(node_to, node_to_dim)
node_to_trace = self._find_trace_from_node(node_to)
node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list)
node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim})
node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim])
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
if exclude == None:
exclude = []
else: else:
compute_to[k] = copy.deepcopy(v) exclude = [self._transform_index(node_to, i) for i in exclude]
node_from_compute = self._find_compute_trace_from_node(node_from)
node_to_compute = self._find_compute_trace_from_node(node_to)
# assert len(node_from_compute) == len(node_to_compute)
for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
if self._transform_index(node_to, i) in exclude:
continue
self._add_source(node_from, i, node_to, i)
for j in node_from_compute[i]:
if j not in node_to_compute[i]:
node_to_compute[i].append(j)
def _mark_idx_equal(self, idx1, idx2): def _mark_idx_equal(self, node1, dim1, node2, dim2):
""" """
Mark 2 index to be equal. Mark 2 index to be equal.
@ -178,7 +238,12 @@ class IndexTracer(object):
idx1 (int): index count. idx1 (int): index count.
idx2 (int): index count. idx2 (int): index count.
""" """
self.idx_trace_equal.append((idx1, idx2)) # node1_idx = _find_idx_by_name(node1.name, self.nodes_list)
# node2_idx = _find_idx_by_name(node2.name, self.nodes_list)
# if node1_idx > node2_idx:
# self._add_source(node2, dim2, node1, dim1)
# else:
# self._add_source(node1, dim1, node2, dim2)
def _mark_computation(self, node, idx, dim): def _mark_computation(self, node, idx, dim):
""" """
@ -189,15 +254,13 @@ class IndexTracer(object):
idx (int): node index idx (int): node index
dim (list or int): dims to be marked as computed dim (list or int): dims to be marked as computed
""" """
input_node_idx_trace = self._find_idx_trace_from_node(node)
if isinstance(dim, int): if isinstance(dim, int):
dim = [dim] dim = [dim]
dims = list(range(len(_get_node_shape(node))))
for d in dim: for d in dim:
cur_idx = input_node_idx_trace[d] cur_dim = dims[d]
if cur_idx not in self.idx_trace_list[idx]['compute']: if idx not in self.idx_trace_list[idx]['compute'][cur_dim]:
self.idx_trace_list[idx]['compute'][cur_idx] = [idx] self.idx_trace_list[idx]['compute'][cur_dim].append(idx)
else:
self.idx_trace_list[idx]['compute'][cur_idx].append(idx)
def _find_trace_from_node(self, node): def _find_trace_from_node(self, node):
""" """
@ -211,7 +274,7 @@ class IndexTracer(object):
""" """
node_idx = _find_idx_by_name(node.name, self.nodes_list) node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_dict = self.idx_trace_list[node_idx] node_dict = self.idx_trace_list[node_idx]
return node_dict['idx'], node_dict['compute'] return node_dict
def _find_idx_trace_from_node(self, node): def _find_idx_trace_from_node(self, node):
""" """
@ -237,7 +300,7 @@ class IndexTracer(object):
node_idx = _find_idx_by_name(node.name, self.nodes_list) node_idx = _find_idx_by_name(node.name, self.nodes_list)
return self.idx_trace_list[node_idx]['compute'] return self.idx_trace_list[node_idx]['compute']
def _assign_index_as_input(self, node, node_idx): def _assign_index_as_input(self, node, node_idx, input_node=None):
""" """
Assign node's trace as its input node. Assign node's trace as its input node.
@ -245,12 +308,16 @@ class IndexTracer(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list) if input_node == None:
input_node = node.args[0]
input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx'] input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']
new_idx_trace = copy.deepcopy(input_node_idx_trace) new_idx_trace = copy.deepcopy(input_node_idx_trace)
self.idx_trace_list[node_idx]['idx'] = new_idx_trace self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self._inherit_all_computation(input_node, node)
def _assign_all_index(self, node, node_idx): def _assign_all_index(self, node, node_idx):
""" """
Add new index for all node's dims. Add new index for all node's dims.
@ -275,15 +342,12 @@ class IndexTracer(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
input_node = node.args[0]
tranpose_dim = node.args[1:] tranpose_dim = node.args[1:]
input_node_idx_trace = self._find_idx_trace_from_node(node.args[0])
new_idx_trace = copy.deepcopy(input_node_idx_trace) self._assign_index_as_input(node, node_idx, input_node)
new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]] self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0])
new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]] self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1])
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self._inherit_computation(node.args[0], node)
def _assign_permute_index(self, node, node_idx): def _assign_permute_index(self, node, node_idx):
""" """
@ -296,14 +360,11 @@ class IndexTracer(object):
node_idx (int) node_idx (int)
""" """
permute_dim = node.args[1:] permute_dim = node.args[1:]
input_node_idx_trace = self._find_idx_trace_from_node(node.args[0]) input_node = node.args[0]
new_idx_trace = copy.deepcopy(input_node_idx_trace) self._assign_index_as_input(node, node_idx, input_node)
for idx, d in enumerate(permute_dim): for idx, d in enumerate(permute_dim):
new_idx_trace[idx] = input_node_idx_trace[d] self._inherit_index(input_node, d, node, idx)
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self._inherit_computation(node.args[0], node)
def _assign_linear_index(self, node, node_idx): def _assign_linear_index(self, node, node_idx):
""" """
@ -321,20 +382,15 @@ class IndexTracer(object):
bias = None bias = None
else: else:
input_node, weight, bias = node.args input_node, weight, bias = node.args
input_node_idx_trace = self._find_idx_trace_from_node(input_node)
weight_idx_trace = self._find_idx_trace_from_node(weight)
new_idx_trace = copy.deepcopy(input_node_idx_trace) self._assign_index_as_input(node, node_idx)
new_idx_trace[-1] = weight_idx_trace[1] self._inherit_index(weight, 1, node, -1)
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self._inherit_computation(input_node, node)
self._mark_computation(node, node_idx, [-1]) self._mark_computation(node, node_idx, [-1])
self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0]) self._mark_idx_equal(input_node, -1, weight, 0)
if bias: if bias:
bias_idx_trace = self._find_idx_trace_from_node(bias) self._mark_idx_equal(input_node, -1, bias, 0)
self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
def _assign_matmul_index(self, node, node_idx): def _assign_matmul_index(self, node, node_idx):
""" """
@ -348,18 +404,14 @@ class IndexTracer(object):
node_idx (int) node_idx (int)
""" """
matmul_left, matmul_right = node.args matmul_left, matmul_right = node.args
matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left)
matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right)
assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace)) assert(len(_get_node_shape(matmul_left)) == len(_get_node_shape(matmul_right)))
new_idx_trace = copy.deepcopy(matmul_left_idx_trace) self._assign_index_as_input(node, node_idx, matmul_left)
new_idx_trace[-1] = matmul_right_idx_trace[-1] self._inherit_index(matmul_right, -1, node, -1)
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
self._inherit_computation(matmul_left, node) self._mark_computation_from_node(matmul_right, node, [-1, -2])
self._inherit_computation(matmul_right, node)
self._mark_computation(node, node_idx, [-1]) self._mark_computation(node, node_idx, [-1])
self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2]) self._mark_idx_equal(matmul_left, -1, matmul_right, -2)
def _assign_layernorm_index(self, node, idx): def _assign_layernorm_index(self, node, idx):
""" """
@ -372,7 +424,6 @@ class IndexTracer(object):
node_idx (int) node_idx (int)
""" """
self._assign_index_as_input(node, idx) self._assign_index_as_input(node, idx)
self._inherit_computation(node.args[0], node)
self._mark_computation(node, idx, [-1, -2]) self._mark_computation(node, idx, [-1, -2])
def _assign_elementwise_index(self, node, idx): def _assign_elementwise_index(self, node, idx):
@ -386,9 +437,59 @@ class IndexTracer(object):
node_idx (int) node_idx (int)
""" """
self._assign_index_as_input(node, idx) self._assign_index_as_input(node, idx)
nodes_in = []
for node_in in node.args: for node_in in node.args:
if type(node_in) not in (int, float): if type(node_in) == type(node):
self._inherit_computation(node_in, node) nodes_in.append(node_in)
self._mark_computation_from_node(node_in, node)
assert len(nodes_in) <= 2
if len(nodes_in) == 2:
node_in0_shape = _get_node_shape(nodes_in[0])
node_in1_shape = _get_node_shape(nodes_in[1])
for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1):
if node_in0_shape[i] == node_in1_shape[i]:
self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i)
def _assgin_no_change_index(self, node, idx):
self._assign_index_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
self._mark_computation_from_node(node_in, node)
def _assign_einsum_index(self, node, idx):
"""
Assign index for einsum op.
Args:
node (node)
node_idx (int)
"""
patterns = node.args[0]
input_nodes = node.args[1:]
patterns = patterns.replace(" ", "")
left, right = patterns.split("->")
left = left.split(",")
all_index = []
for i in left:
for c in i:
all_index.append(c)
all_index = set(all_index)
free_index = set([i for i in right])
sum_index = all_index - free_index
for right_idx, right_indice in enumerate(right):
for left_idx, left_str in enumerate(left):
if right_indice in left_str:
source_idx = left_str.index(right_indice)
self._inherit_index(input_nodes[left_idx], source_idx, node, right_idx)
for i in sum_index:
for left_idx, left_str in enumerate(left):
if i in left_str:
self._mark_computation(node, idx, left_str.index(i))
break
def _assign_softmax_index(self, node, idx): def _assign_softmax_index(self, node, idx):
""" """
@ -401,7 +502,6 @@ class IndexTracer(object):
node_idx (int) node_idx (int)
""" """
self._assign_index_as_input(node, idx) self._assign_index_as_input(node, idx)
self._inherit_computation(node.args[0], node)
self._mark_computation(node, idx, [node.kwargs['dim']]) self._mark_computation(node, idx, [node.kwargs['dim']])
def _assign_unsqueeze_index(self, node, node_idx): def _assign_unsqueeze_index(self, node, node_idx):
@ -413,9 +513,11 @@ class IndexTracer(object):
node (node) node (node)
node_idx (int) node_idx (int)
""" """
self._del_dim(node_idx, -1)
self._assign_index_as_input(node, node_idx) self._assign_index_as_input(node, node_idx)
self._inherit_computation(node.args[0], node)
self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index()) self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index())
self.idx_trace_list[node_idx]['compute'].insert(node.args[1], [])
self.idx_trace_list[node_idx]['source'].insert(node.args[1], [])
def _assign_dropout_index(self, node, node_idx): def _assign_dropout_index(self, node, node_idx):
""" """
@ -428,7 +530,6 @@ class IndexTracer(object):
""" """
self._assign_index_as_input(node, node_idx) self._assign_index_as_input(node, node_idx)
def _assign_ones_like_index(self, node, node_idx): def _assign_ones_like_index(self, node, node_idx):
""" """
Assign index for oneslike op. Assign index for oneslike op.
@ -440,17 +541,6 @@ class IndexTracer(object):
""" """
self._assign_all_index(node, node_idx) self._assign_all_index(node, node_idx)
def _assign_to_index(self, node, node_idx):
"""
Assign index for to op.
1. assign new index for all dim
Args:
node (node)
node_idx (int)
"""
self._assign_index_as_input(node, node_idx)
def _assign_view_reshape_index(self, node, node_idx): def _assign_view_reshape_index(self, node, node_idx):
""" """
Assign index for view and reshape op. Assign index for view and reshape op.
@ -494,26 +584,26 @@ class IndexTracer(object):
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)] dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
dim_to = [dim_equal.index(False)] dim_to = [dim_equal.index(False)]
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1] dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
self._add_dim(node_idx, -1)
elif len_diff == -1: elif len_diff == -1:
# dim expand # dim expand
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])] dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
dim_from = [dim_equal.index(False)] dim_from = [dim_equal.index(False)]
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
self._del_dim(node_idx, -1)
else: else:
raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented") raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented")
# get new index # get new index
origin_trace = self._find_idx_trace_from_node(origin_node) origin_trace = self._find_idx_trace_from_node(origin_node)
new_trace = copy.deepcopy(origin_trace) self._assign_index_as_input(node, node_idx, origin_node)
dim_from.reverse() dim_from.reverse()
for i in dim_from: for i in dim_from:
new_trace.pop(i) self._del_dim(node_idx, i)
for i in dim_to: for i in dim_to:
new_trace.insert(i, self._add_index()) self._add_dim(node_idx, i)
self.idx_trace_list[node_idx]['idx'] = new_trace
# inherit computation # inherit computation
self._inherit_computation(origin_node, node)
compute_log = self._find_compute_trace_from_node(origin_node) compute_log = self._find_compute_trace_from_node(origin_node)
for i in dim_from: for i in dim_from:
if origin_trace[i] in compute_log: if origin_trace[i] in compute_log:
@ -524,15 +614,10 @@ class IndexTracer(object):
# log view, not used now # log view, not used now
view_dict = {"idx_from": [origin_trace[i] for i in dim_from], view_dict = {"idx_from": [origin_trace[i] for i in dim_from],
"dim_from": dim_from, "dim_from": dim_from,
"idx_to": [new_trace[i] for i in dim_to], "idx_to": [self.idx_trace_list[node_idx]['idx'][i] for i in dim_to],
"dim_to": dim_to} "dim_to": dim_to}
self.idx_view_list.append(view_dict) self.idx_view_list.append(view_dict)
def _remove_duplicate_compute(self):
for i in self.idx_trace_list:
for k, v in i['compute'].items():
i['compute'][k] = list(set(v))
def _merge_equal_idx(self): def _merge_equal_idx(self):
idx_equal = copy.deepcopy(self.idx_trace_equal) idx_equal = copy.deepcopy(self.idx_trace_equal)
idx_equal.reverse() idx_equal.reverse()
@ -556,8 +641,8 @@ class IndexTracer(object):
self._assign_view_reshape_index(node, idx) self._assign_view_reshape_index(node, idx)
elif 'unsqueeze' in node.name: elif 'unsqueeze' in node.name:
self._assign_unsqueeze_index(node, idx) self._assign_unsqueeze_index(node, idx)
elif 'to' in node.name: elif any(i in node.name for i in ['to', 'contiguous']):
self._assign_to_index(node, idx) self._assgin_no_change_index(node, idx)
else: else:
raise NotImplementedError(node.name, "method not implemented yet!") raise NotImplementedError(node.name, "method not implemented yet!")
elif node.op == 'call_function': elif node.op == 'call_function':
@ -573,6 +658,8 @@ class IndexTracer(object):
self._assign_ones_like_index(node, idx) self._assign_ones_like_index(node, idx)
elif 'dropout' in node.name: elif 'dropout' in node.name:
self._assign_dropout_index(node, idx) self._assign_dropout_index(node, idx)
elif 'einsum' in node.name:
self._assign_einsum_index(node, idx)
elif 'getattr' in node.name: elif 'getattr' in node.name:
continue # get attr like shape continue # get attr like shape
elif 'getitem' in node.name: elif 'getitem' in node.name:
@ -590,10 +677,20 @@ class IndexTracer(object):
continue continue
else: else:
raise NotImplementedError(node.op, "op not implemented yet!") raise NotImplementedError(node.op, "op not implemented yet!")
# self._merge_equal_idx()
self._remove_duplicate_compute() def check_index(self, trace_idx, start_idx, end_idx):
self._merge_equal_idx() for i in range(start_idx, end_idx + 1):
cur_idx = self.idx_trace_list[i]['idx']
cur_compute = self.idx_trace_list[i]['compute']
if trace_idx in cur_compute:
for j in cur_compute[trace_idx]:
if j < start_idx or j > end_idx:
return False
# same_idx = [1 if j == trace_idx else 0 for j in cur_idx]
# if sum(same_idx) > 1:
# return False
return True
class MemoryEstimator(object): class MemoryEstimator(object):
def __init__(self) -> None: def __init__(self) -> None:
@ -897,6 +994,8 @@ class ChunkRegionSearch(object):
self._is_not_compute(after_trace, (start_idx, end_idx), i) and self._is_not_compute(after_trace, (start_idx, end_idx), i) and
self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1): self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1):
continue continue
if not self.index_tracer.check_index(before_trace['idx'][i], start_idx, end_idx):
continue
flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i) flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i)
if flow_flag == None: if flow_flag == None:
continue continue
@ -910,6 +1009,9 @@ class ChunkRegionSearch(object):
input_trace = [] input_trace = []
for i, n in enumerate(self.node_list): for i, n in enumerate(self.node_list):
if len(n.args) > 0 and n.op != 'output': if len(n.args) > 0 and n.op != 'output':
if isinstance(n.args[0], str):
input_idx = _find_idx_by_name(n.args[1].name, self.node_list)
else:
input_idx = _find_idx_by_name(n.args[0].name, self.node_list) input_idx = _find_idx_by_name(n.args[0].name, self.node_list)
input_trace.append(output_trace[input_idx]) input_trace.append(output_trace[input_idx])
else: else:
@ -1130,6 +1232,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
if node_idx in chunk_starts: if node_idx in chunk_starts:
within_chunk_region = True within_chunk_region = True
region_idx = chunk_starts.index(node_idx)
# add for loop # add for loop
chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]] chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]]
@ -1150,7 +1253,6 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
if node_idx in chunk_ends: if node_idx in chunk_ends:
body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx])) body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx]))
within_chunk_region = False within_chunk_region = False
region_idx += 1
node_idx += 1 node_idx += 1