mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-19 18:00:43 +00:00
add doc string
This commit is contained in:
parent
c36dba07de
commit
70a98b8f56
@ -26,13 +26,28 @@ class NodeIndexTracer(object):
|
|||||||
self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
|
self.idx_trace_list = [{'idx': [], 'compute': []} for _ in range(len(self.nodes_list))]
|
||||||
self.idx_trace_equal = []
|
self.idx_trace_equal = []
|
||||||
self.idx_view_list = []
|
self.idx_view_list = []
|
||||||
self.idx_count = 1
|
self.idx_count = -1
|
||||||
|
|
||||||
def add_index(self):
|
def add_index(self):
|
||||||
|
"""
|
||||||
|
Update the count and return it. To record the idx number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
idx_count: int
|
||||||
|
"""
|
||||||
self.idx_count += 1
|
self.idx_count += 1
|
||||||
return self.idx_count - 1
|
return self.idx_count
|
||||||
|
|
||||||
def inherit_computation(self, node_from, node_to):
|
def inherit_computation(self, node_from, node_to):
|
||||||
|
"""
|
||||||
|
Inherit computed dim from node_from to node_to.
|
||||||
|
If a dim in node_from is marked as computed and exists in node_to,
|
||||||
|
still mark it as computed in node_to.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_from (node): node to be inherited
|
||||||
|
node_to (node): new node to inherit
|
||||||
|
"""
|
||||||
_, compute_from = self.find_trace_from_node(node_from)
|
_, compute_from = self.find_trace_from_node(node_from)
|
||||||
idx_to, compute_to = self.find_trace_from_node(node_to)
|
idx_to, compute_to = self.find_trace_from_node(node_to)
|
||||||
for i in compute_from:
|
for i in compute_from:
|
||||||
@ -40,9 +55,24 @@ class NodeIndexTracer(object):
|
|||||||
compute_to.append(i)
|
compute_to.append(i)
|
||||||
|
|
||||||
def mark_idx_equal(self, idx1, idx2):
|
def mark_idx_equal(self, idx1, idx2):
|
||||||
|
"""
|
||||||
|
Mark 2 index to be equal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx1 (int): index count.
|
||||||
|
idx2 (int): index count.
|
||||||
|
"""
|
||||||
self.idx_trace_equal.append((idx1, idx2))
|
self.idx_trace_equal.append((idx1, idx2))
|
||||||
|
|
||||||
def mark_computation(self, node, idx, dim):
|
def mark_computation(self, node, idx, dim):
|
||||||
|
"""
|
||||||
|
Mark some dims of node as computed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
idx (int): node index
|
||||||
|
dim (list or int): dims to be marked as computed
|
||||||
|
"""
|
||||||
input_node_idx_trace = self.find_idx_trace_from_node(node)
|
input_node_idx_trace = self.find_idx_trace_from_node(node)
|
||||||
if isinstance(dim, int):
|
if isinstance(dim, int):
|
||||||
dim = [dim]
|
dim = [dim]
|
||||||
@ -52,15 +82,40 @@ class NodeIndexTracer(object):
|
|||||||
self.idx_trace_list[idx]['compute'].append(cur_idx)
|
self.idx_trace_list[idx]['compute'].append(cur_idx)
|
||||||
|
|
||||||
def find_trace_from_node(self, node):
|
def find_trace_from_node(self, node):
|
||||||
|
"""
|
||||||
|
Find node idx and compute trace by the node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
Returns:
|
||||||
|
idx (list): idx of the node
|
||||||
|
compute (list): computed idx of the node.
|
||||||
|
"""
|
||||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||||
node_dict = self.idx_trace_list[node_idx]
|
node_dict = self.idx_trace_list[node_idx]
|
||||||
return node_dict['idx'], node_dict['compute']
|
return node_dict['idx'], node_dict['compute']
|
||||||
|
|
||||||
def find_idx_trace_from_node(self, node):
|
def find_idx_trace_from_node(self, node):
|
||||||
|
"""
|
||||||
|
Find node idx trace by the node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
Returns:
|
||||||
|
idx (list): idx of the node
|
||||||
|
"""
|
||||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||||
return self.idx_trace_list[node_idx]['idx']
|
return self.idx_trace_list[node_idx]['idx']
|
||||||
|
|
||||||
def find_compute_trace_from_node(self, node):
|
def find_compute_trace_from_node(self, node):
|
||||||
|
"""
|
||||||
|
Find node compute trace by the node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
Returns:
|
||||||
|
compute (list): computed idx of the node.
|
||||||
|
"""
|
||||||
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
node_idx = _find_idx_by_name(node.name, self.nodes_list)
|
||||||
return self.idx_trace_list[node_idx]['compute']
|
return self.idx_trace_list[node_idx]['compute']
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user