diff --git a/chunk_codegen.py b/chunk_codegen.py index 8477fe9a1..aa9d7ecd8 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -26,13 +26,28 @@ class NodeIndexTracer(object): self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))] self.idx_trace_equal = [] self.idx_view_list = [] - self.idx_count = 1 + self.idx_count = -1 def add_index(self): + """ + Update the count and return it. To record the idx number. + + Returns: + idx_count: int + """ self.idx_count += 1 - return self.idx_count - 1 + return self.idx_count def inherit_computation(self, node_from, node_to): + """ + Inherit computed dim from node_from to node_to. + If a dim in node_from is marked as computed and exists in node_to, + still mark it as computed in node_to. + + Args: + node_from (node): node to be inherited + node_to (node): new node to inherit + """ _, compute_from = self.find_trace_from_node(node_from) idx_to, compute_to = self.find_trace_from_node(node_to) for i in compute_from: @@ -40,9 +55,24 @@ class NodeIndexTracer(object): compute_to.append(i) def mark_idx_equal(self, idx1, idx2): + """ + Mark 2 index to be equal. + + Args: + idx1 (int): index count. + idx2 (int): index count. + """ self.idx_trace_equal.append((idx1, idx2)) def mark_computation(self, node, idx, dim): + """ + Mark some dims of node as computed. + + Args: + node (node) + idx (int): node index + dim (list or int): dims to be marked as computed + """ input_node_idx_trace = self.find_idx_trace_from_node(node) if isinstance(dim, int): dim = [dim] @@ -52,15 +82,40 @@ class NodeIndexTracer(object): self.idx_trace_list[idx]['compute'].append(cur_idx) def find_trace_from_node(self, node): + """ + Find node idx and compute trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + compute (list): computed idx of the node. + """ node_idx = _find_idx_by_name(node.name, self.nodes_list) node_dict = self.idx_trace_list[node_idx] return node_dict['idx'], node_dict['compute'] def find_idx_trace_from_node(self, node): + """ + Find node idx trace by the node. + + Args: + node (node) + Returns: + idx (list): idx of the node + """ node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['idx'] def find_compute_trace_from_node(self, node): + """ + Find node compute trace by the node. + + Args: + node (node) + Returns: + compute (list): computed idx of the node. + """ node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['compute']