mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-19 01:39:26 +00:00
add doc string
This commit is contained in:
parent
c36dba07de
commit
70a98b8f56
@ -26,13 +26,28 @@ class NodeIndexTracer(object):
|
||||
self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
|
||||
self.idx_trace_equal = []
|
||||
self.idx_view_list = []
|
||||
self.idx_count = 1
|
||||
self.idx_count = -1
|
||||
|
||||
def add_index(self):
|
||||
"""
|
||||
Update the count and return it. To record the idx number.
|
||||
|
||||
Returns:
|
||||
idx_count: int
|
||||
"""
|
||||
self.idx_count += 1
|
||||
return self.idx_count - 1
|
||||
return self.idx_count
|
||||
|
||||
def inherit_computation(self, node_from, node_to):
|
||||
"""
|
||||
Inherit computed dim from node_from to node_to.
|
||||
If a dim in node_from is marked as computed and exists in node_to,
|
||||
still mark it as computed in node_to.
|
||||
|
||||
Args:
|
||||
node_from (node): node to be inherited
|
||||
node_to (node): new node to inherit
|
||||
"""
|
||||
_, compute_from = self.find_trace_from_node(node_from)
|
||||
idx_to, compute_to = self.find_trace_from_node(node_to)
|
||||
for i in compute_from:
|
||||
@ -40,9 +55,24 @@ class NodeIndexTracer(object):
|
||||
compute_to.append(i)
|
||||
|
||||
def mark_idx_equal(self, idx1, idx2):
|
||||
"""
|
||||
Mark 2 index to be equal.
|
||||
|
||||
Args:
|
||||
idx1 (int): index count.
|
||||
idx2 (int): index count.
|
||||
"""
|
||||
self.idx_trace_equal.append((idx1, idx2))
|
||||
|
||||
def mark_computation(self, node, idx, dim):
|
||||
"""
|
||||
Mark some dims of node as computed.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
idx (int): node index
|
||||
dim (list or int): dims to be marked as computed
|
||||
"""
|
||||
input_node_idx_trace = self.find_idx_trace_from_node(node)
|
||||
if isinstance(dim, int):
|
||||
dim = [dim]
|
||||
@ -52,15 +82,40 @@ class NodeIndexTracer(object):
|
||||
self.idx_trace_list[idx]['compute'].append(cur_idx)
|
||||
|
||||
def find_trace_from_node(self, node):
|
||||
"""
|
||||
Find node idx and compute trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
idx (list): idx of the node
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
node_dict = self.idx_trace_list[node_idx]
|
||||
return node_dict['idx'], node_dict['compute']
|
||||
|
||||
def find_idx_trace_from_node(self, node):
|
||||
"""
|
||||
Find node idx trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
idx (list): idx of the node
|
||||
"""
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
return self.idx_trace_list[node_idx]['idx']
|
||||
|
||||
def find_compute_trace_from_node(self, node):
|
||||
"""
|
||||
Find node compute trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||
return self.idx_trace_list[node_idx]['compute']
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user