mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
support new op
This commit is contained in:
parent
f24c418bb0
commit
a9d64377bb
@ -201,6 +201,10 @@ class NodeIndexTracer(object):
|
|||||||
node (node)
|
node (node)
|
||||||
node_idx (int)
|
node_idx (int)
|
||||||
"""
|
"""
|
||||||
|
if len(node.args) == 2:
|
||||||
|
input_node, weight = node.args
|
||||||
|
bias = None
|
||||||
|
else:
|
||||||
input_node, weight, bias = node.args
|
input_node, weight, bias = node.args
|
||||||
input_node_idx_trace = self._find_idx_trace_from_node(input_node)
|
input_node_idx_trace = self._find_idx_trace_from_node(input_node)
|
||||||
weight_idx_trace = self._find_idx_trace_from_node(weight)
|
weight_idx_trace = self._find_idx_trace_from_node(weight)
|
||||||
@ -285,6 +289,53 @@ class NodeIndexTracer(object):
|
|||||||
self._inherit_computation(node.args[0], node)
|
self._inherit_computation(node.args[0], node)
|
||||||
self._mark_computation(node, idx, [node.kwargs['dim']])
|
self._mark_computation(node, idx, [node.kwargs['dim']])
|
||||||
|
|
||||||
|
def _assign_unsqueeze_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign index for unsqueeze op.
|
||||||
|
1. assign new index for unsqueeze dim
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._assign_index_as_input(node, node_idx)
|
||||||
|
self._inherit_computation(node.args[0], node)
|
||||||
|
self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index())
|
||||||
|
|
||||||
|
def _assign_dropout_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign index for unsqueeze op.
|
||||||
|
1. assign new index for unsqueeze dim
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._assign_index_as_input(node, node_idx)
|
||||||
|
|
||||||
|
|
||||||
|
def _assign_ones_like_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign index for oneslike op.
|
||||||
|
1. assign new index for all dim
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._assign_all_index(node, node_idx)
|
||||||
|
|
||||||
|
def _assign_to_index(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign index for to op.
|
||||||
|
1. assign new index for all dim
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._assign_index_as_input(node, node_idx)
|
||||||
|
|
||||||
def _assign_view_reshape_index(self, node, node_idx):
|
def _assign_view_reshape_index(self, node, node_idx):
|
||||||
"""
|
"""
|
||||||
Assign index for view and reshape op.
|
Assign index for view and reshape op.
|
||||||
@ -388,6 +439,10 @@ class NodeIndexTracer(object):
|
|||||||
self._assign_permute_index(node, idx)
|
self._assign_permute_index(node, idx)
|
||||||
elif 'view' in node.name or 'reshape' in node.name:
|
elif 'view' in node.name or 'reshape' in node.name:
|
||||||
self._assign_view_reshape_index(node, idx)
|
self._assign_view_reshape_index(node, idx)
|
||||||
|
elif 'unsqueeze' in node.name:
|
||||||
|
self._assign_unsqueeze_index(node, idx)
|
||||||
|
elif 'to' in node.name:
|
||||||
|
self._assign_to_index(node, idx)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(node.name, "method not implemented yet!")
|
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||||
elif node.op == 'call_function':
|
elif node.op == 'call_function':
|
||||||
@ -399,6 +454,10 @@ class NodeIndexTracer(object):
|
|||||||
self._assign_softmax_index(node, idx)
|
self._assign_softmax_index(node, idx)
|
||||||
elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
|
elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']):
|
||||||
self._assign_elementwise_index(node, idx)
|
self._assign_elementwise_index(node, idx)
|
||||||
|
elif 'ones_like' in node.name:
|
||||||
|
self._assign_ones_like_index(node, idx)
|
||||||
|
elif 'dropout' in node.name:
|
||||||
|
self._assign_dropout_index(node, idx)
|
||||||
elif 'getattr' in node.name:
|
elif 'getattr' in node.name:
|
||||||
continue # get attr like shape
|
continue # get attr like shape
|
||||||
elif 'getitem' in node.name:
|
elif 'getitem' in node.name:
|
||||||
|
Loading…
Reference in New Issue
Block a user