mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 06:00:44 +00:00
add doc str
This commit is contained in:
parent
70a98b8f56
commit
f379d1a94d
@ -120,6 +120,13 @@ class NodeIndexTracer(object):
|
|||||||
return self.idx_trace_list[node_idx]['compute']
|
return self.idx_trace_list[node_idx]['compute']
|
||||||
|
|
||||||
def assign_index_as_input(self, node, node_idx):
|
def assign_index_as_input(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign node's trace as its input node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
|
input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
|
||||||
input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']
|
input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']
|
||||||
|
|
||||||
@ -127,6 +134,13 @@ class NodeIndexTracer(object):
|
|||||||
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
self.idx_trace_list[node_idx]['idx'] = new_idx_trace
|
||||||
|
|
||||||
def assign_all_index(self, node, node_idx):
|
def assign_all_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Add new index for all node's dims.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
shape = node.meta['tensor_meta'].shape
|
shape = node.meta['tensor_meta'].shape
|
||||||
new_trace = []
|
new_trace = []
|
||||||
for _ in shape:
|
for _ in shape:
|
||||||
@ -134,6 +148,15 @@ class NodeIndexTracer(object):
|
|||||||
self.idx_trace_list[node_idx]['idx'] = new_trace
|
self.idx_trace_list[node_idx]['idx'] = new_trace
|
||||||
|
|
||||||
def assign_transpose_index(self, node, node_idx):
|
def assign_transpose_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign index for transpose op.
|
||||||
|
1. swap input's dim according to transpose args
|
||||||
|
2. inherit input's computation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
tranpose_dim = node.args[1:]
|
tranpose_dim = node.args[1:]
|
||||||
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
|
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
|
||||||
|
|
||||||
@ -145,6 +168,15 @@ class NodeIndexTracer(object):
|
|||||||
self.inherit_computation(node.args[0], node)
|
self.inherit_computation(node.args[0], node)
|
||||||
|
|
||||||
def assign_permute_index(self, node, node_idx):
|
def assign_permute_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign index for permute op.
|
||||||
|
1. swap input's dim according to permute args
|
||||||
|
2. inherit input's computation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
permute_dim = node.args[1:]
|
permute_dim = node.args[1:]
|
||||||
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
|
input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
|
||||||
|
|
||||||
@ -156,6 +188,16 @@ class NodeIndexTracer(object):
|
|||||||
self.inherit_computation(node.args[0], node)
|
self.inherit_computation(node.args[0], node)
|
||||||
|
|
||||||
def assign_linear_index(self, node, node_idx):
|
def assign_linear_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign index for linear op.
|
||||||
|
1. copy trace from input node and change last index accroding to weight
|
||||||
|
2. mark equal for input node last index, weight first dim and bias dim.
|
||||||
|
3. inherit input's computation, mark computation for last dim.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
input_node, weight, bias = node.args
|
input_node, weight, bias = node.args
|
||||||
input_node_idx_trace = self.find_idx_trace_from_node(input_node)
|
input_node_idx_trace = self.find_idx_trace_from_node(input_node)
|
||||||
weight_idx_trace = self.find_idx_trace_from_node(weight)
|
weight_idx_trace = self.find_idx_trace_from_node(weight)
|
||||||
@ -173,6 +215,16 @@ class NodeIndexTracer(object):
|
|||||||
self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
|
self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
|
||||||
|
|
||||||
def assign_matmul_index(self, node, node_idx):
|
def assign_matmul_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign index for matmul op.
|
||||||
|
1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
|
||||||
|
2. mark equal for input matmul_left -1 index and matmul_right -2 dim.
|
||||||
|
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
matmul_left, matmul_right = node.args
|
matmul_left, matmul_right = node.args
|
||||||
matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
|
matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
|
||||||
matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
|
matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
|
||||||
@ -188,21 +240,63 @@ class NodeIndexTracer(object):
|
|||||||
self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
|
self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
|
||||||
|
|
||||||
def assign_layernorm_index(self, node, idx):
|
def assign_layernorm_index(self, node, idx):
|
||||||
|
"""
|
||||||
|
Assign index for layernorm op.
|
||||||
|
1. assign index as input node
|
||||||
|
2. inherit computation and mark last 2 dims as computed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
self.assign_index_as_input(node, idx)
|
self.assign_index_as_input(node, idx)
|
||||||
self.inherit_computation(node.args[0], node)
|
self.inherit_computation(node.args[0], node)
|
||||||
self.mark_computation(node, idx, [-1, -2])
|
self.mark_computation(node, idx, [-1, -2])
|
||||||
|
|
||||||
def assign_elementwise_index(self, node, idx):
|
def assign_elementwise_index(self, node, idx):
|
||||||
|
"""
|
||||||
|
Assign index for element-wise op (eg. relu sigmoid add mul).
|
||||||
|
1. assign index as input node
|
||||||
|
2. inherit computation from all input nodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
self.assign_index_as_input(node, idx)
|
self.assign_index_as_input(node, idx)
|
||||||
for node_in in node.args:
|
for node_in in node.args:
|
||||||
if type(node_in) not in (int, float):
|
if type(node_in) not in (int, float):
|
||||||
self.inherit_computation(node_in, node)
|
self.inherit_computation(node_in, node)
|
||||||
|
|
||||||
def assign_softmax_index(self, node, idx):
|
def assign_softmax_index(self, node, idx):
|
||||||
|
"""
|
||||||
|
Assign index for softmax op.
|
||||||
|
1. assign index as input node
|
||||||
|
2. inherit computation and mark softmax dim as computed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
self.assign_index_as_input(node, idx)
|
self.assign_index_as_input(node, idx)
|
||||||
|
self.inherit_computation(node.args[0], node)
|
||||||
self.mark_computation(node, idx, [node.kwargs['dim']])
|
self.mark_computation(node, idx, [node.kwargs['dim']])
|
||||||
|
|
||||||
def assign_view_reshape_index(self, node, node_idx):
|
def assign_view_reshape_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign index for view and reshape op.
|
||||||
|
1. get origin shape and target shape by meta info.
|
||||||
|
2. compute the real value of -1 in target shape.
|
||||||
|
3. determine changed dim, and assgin index for generated dim.
|
||||||
|
4. log changed dim and generated dim for restore
|
||||||
|
5. look into view list to see whether the view is associated with other,
|
||||||
|
if so assgin equal dim according to previous view.
|
||||||
|
6. inherit computation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
# get data, turn into number
|
# get data, turn into number
|
||||||
origin_node = node.args[0]
|
origin_node = node.args[0]
|
||||||
origin_shape = origin_node.meta['tensor_meta'].shape
|
origin_shape = origin_node.meta['tensor_meta'].shape
|
||||||
@ -305,6 +399,7 @@ class NodeIndexTracer(object):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(node.op, "op not implemented yet!")
|
raise NotImplementedError(node.op, "op not implemented yet!")
|
||||||
|
|
||||||
|
|
||||||
def _get_meta_node_size(x):
|
def _get_meta_node_size(x):
|
||||||
x = x.meta['tensor_meta']
|
x = x.meta['tensor_meta']
|
||||||
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
|
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
|
||||||
|
Loading…
Reference in New Issue
Block a user