diff --git a/chunk_codegen.py b/chunk_codegen.py index 6740cd44a..2dc44d381 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -200,8 +200,12 @@ class NodeIndexTracer(object): Args: node (node) node_idx (int) - """ - input_node, weight, bias = node.args + """ + if len(node.args) == 2: + input_node, weight = node.args + bias = None + else: + input_node, weight, bias = node.args input_node_idx_trace = self._find_idx_trace_from_node(input_node) weight_idx_trace = self._find_idx_trace_from_node(weight) @@ -284,6 +288,53 @@ class NodeIndexTracer(object): self._assign_index_as_input(node, idx) self._inherit_computation(node.args[0], node) self._mark_computation(node, idx, [node.kwargs['dim']]) + + def _assign_unsqueeze_index(self, node, node_idx): + """ + Assign index for unsqueeze op. + 1. assign new index for unsqueeze dim + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, node_idx) + self._inherit_computation(node.args[0], node) + self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index()) + + def _assign_dropout_index(self, node, node_idx): + """ + Assign index for unsqueeze op. + 1. assign new index for unsqueeze dim + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, node_idx) + + + def _assign_ones_like_index(self, node, node_idx): + """ + Assign index for oneslike op. + 1. assign new index for all dim + + Args: + node (node) + node_idx (int) + """ + self._assign_all_index(node, node_idx) + + def _assign_to_index(self, node, node_idx): + """ + Assign index for to op. + 1. assign new index for all dim + + Args: + node (node) + node_idx (int) + """ + self._assign_index_as_input(node, node_idx) def _assign_view_reshape_index(self, node, node_idx): """ @@ -388,6 +439,10 @@ class NodeIndexTracer(object): self._assign_permute_index(node, idx) elif 'view' in node.name or 'reshape' in node.name: self._assign_view_reshape_index(node, idx) + elif 'unsqueeze' in node.name: + self._assign_unsqueeze_index(node, idx) + elif 'to' in node.name: + self._assign_to_index(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") elif node.op == 'call_function': @@ -399,6 +454,10 @@ class NodeIndexTracer(object): self._assign_softmax_index(node, idx) elif any(n in node.name for n in ['mul', 'add', 'sigmoid', 'relu']): self._assign_elementwise_index(node, idx) + elif 'ones_like' in node.name: + self._assign_ones_like_index(node, idx) + elif 'dropout' in node.name: + self._assign_dropout_index(node, idx) elif 'getattr' in node.name: continue # get attr like shape elif 'getitem' in node.name: