add doc string

This commit is contained in:
oahzxl 2022-11-14 23:49:48 +08:00
parent c36dba07de
commit 70a98b8f56

View File

@ -26,13 +26,28 @@ class NodeIndexTracer(object):
self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))] self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
self.idx_trace_equal = [] self.idx_trace_equal = []
self.idx_view_list = [] self.idx_view_list = []
self.idx_count = 1 self.idx_count = -1
def add_index(self): def add_index(self):
"""
Update the count and return it. To record the idx number.
Returns:
idx_count: int
"""
self.idx_count += 1 self.idx_count += 1
return self.idx_count - 1 return self.idx_count
def inherit_computation(self, node_from, node_to): def inherit_computation(self, node_from, node_to):
"""
Inherit computed dim from node_from to node_to.
If a dim in node_from is marked as computed and exists in node_to,
still mark it as computed in node_to.
Args:
node_from (node): node to be inherited
node_to (node): new node to inherit
"""
_, compute_from = self.find_trace_from_node(node_from) _, compute_from = self.find_trace_from_node(node_from)
idx_to, compute_to = self.find_trace_from_node(node_to) idx_to, compute_to = self.find_trace_from_node(node_to)
for i in compute_from: for i in compute_from:
@ -40,9 +55,24 @@ class NodeIndexTracer(object):
compute_to.append(i) compute_to.append(i)
def mark_idx_equal(self, idx1, idx2): def mark_idx_equal(self, idx1, idx2):
"""
Mark 2 index to be equal.
Args:
idx1 (int): index count.
idx2 (int): index count.
"""
self.idx_trace_equal.append((idx1, idx2)) self.idx_trace_equal.append((idx1, idx2))
def mark_computation(self, node, idx, dim): def mark_computation(self, node, idx, dim):
"""
Mark some dims of node as computed.
Args:
node (node)
idx (int): node index
dim (list or int): dims to be marked as computed
"""
input_node_idx_trace = self.find_idx_trace_from_node(node) input_node_idx_trace = self.find_idx_trace_from_node(node)
if isinstance(dim, int): if isinstance(dim, int):
dim = [dim] dim = [dim]
@ -52,15 +82,40 @@ class NodeIndexTracer(object):
self.idx_trace_list[idx]['compute'].append(cur_idx) self.idx_trace_list[idx]['compute'].append(cur_idx)
def find_trace_from_node(self, node): def find_trace_from_node(self, node):
"""
Find node idx and compute trace by the node.
Args:
node (node)
Returns:
idx (list): idx of the node
compute (list): computed idx of the node.
"""
node_idx = _find_idx_by_name(node.name, self.nodes_list) node_idx = _find_idx_by_name(node.name, self.nodes_list)
node_dict = self.idx_trace_list[node_idx] node_dict = self.idx_trace_list[node_idx]
return node_dict['idx'], node_dict['compute'] return node_dict['idx'], node_dict['compute']
def find_idx_trace_from_node(self, node): def find_idx_trace_from_node(self, node):
"""
Find node idx trace by the node.
Args:
node (node)
Returns:
idx (list): idx of the node
"""
node_idx = _find_idx_by_name(node.name, self.nodes_list) node_idx = _find_idx_by_name(node.name, self.nodes_list)
return self.idx_trace_list[node_idx]['idx'] return self.idx_trace_list[node_idx]['idx']
def find_compute_trace_from_node(self, node): def find_compute_trace_from_node(self, node):
"""
Find node compute trace by the node.
Args:
node (node)
Returns:
compute (list): computed idx of the node.
"""
node_idx = _find_idx_by_name(node.name, self.nodes_list) node_idx = _find_idx_by_name(node.name, self.nodes_list)
return self.idx_trace_list[node_idx]['compute'] return self.idx_trace_list[node_idx]['compute']