diff --git a/chunk_codegen.py b/chunk_codegen.py
index aa9d7ecd8..a14f7c134 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -120,6 +120,13 @@ class NodeIndexTracer(object):
         return self.idx_trace_list[node_idx]['compute']
 
     def assign_index_as_input(self, node, node_idx):
+        """
+        Assign the node's trace to be the same as its input node's.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
         input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list)
         input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx']
 
@@ -127,6 +134,13 @@ class NodeIndexTracer(object):
         self.idx_trace_list[node_idx]['idx'] = new_idx_trace
 
     def assign_all_index(self, node, node_idx):
+        """
+        Add a new index for every dim of the node.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
         shape = node.meta['tensor_meta'].shape
         new_trace = []
         for _ in shape:
@@ -134,6 +148,15 @@ class NodeIndexTracer(object):
         self.idx_trace_list[node_idx]['idx'] = new_trace
 
     def assign_transpose_index(self, node, node_idx):
+        """
+        Assign index for transpose op.
+        1. swap the input's dims according to the transpose args
+        2. inherit the input's computation
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
         tranpose_dim = node.args[1:]
         input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
 
@@ -145,6 +168,15 @@ class NodeIndexTracer(object):
         self.inherit_computation(node.args[0], node)
 
     def assign_permute_index(self, node, node_idx):
+        """
+        Assign index for permute op.
+        1. swap the input's dims according to the permute args
+        2. inherit the input's computation
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
         permute_dim = node.args[1:]
         input_node_idx_trace = self.find_idx_trace_from_node(node.args[0])
 
@@ -156,6 +188,16 @@ class NodeIndexTracer(object):
         self.inherit_computation(node.args[0], node)
 
     def assign_linear_index(self, node, node_idx):
+        """
+        Assign index for linear op.
+        1. copy the trace from the input node and change the last index according to the weight
+        2. mark the input node's last index, the weight's first dim and the bias dim as equal
+        3. inherit the input's computation, mark computation for the last dim
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
         input_node, weight, bias = node.args
         input_node_idx_trace = self.find_idx_trace_from_node(input_node)
         weight_idx_trace = self.find_idx_trace_from_node(weight)
@@ -173,6 +215,16 @@ class NodeIndexTracer(object):
         self.mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0])
 
     def assign_matmul_index(self, node, node_idx):
+        """
+        Assign index for matmul op.
+        1. copy the trace from matmul_left and change the last index according to matmul_right (assert they have the same length)
+        2. mark matmul_left's -1 index and matmul_right's -2 index as equal
+        3. inherit matmul_left's and matmul_right's computation, mark computation for the last dim
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
         matmul_left, matmul_right = node.args
         matmul_left_idx_trace = self.find_idx_trace_from_node(matmul_left)
         matmul_right_idx_trace = self.find_idx_trace_from_node(matmul_right)
@@ -188,21 +240,63 @@ class NodeIndexTracer(object):
         self.mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2])
 
     def assign_layernorm_index(self, node, idx):
+        """
+        Assign index for layernorm op.
+        1. assign index as input node
+        2. inherit computation and mark the last 2 dims as computed
+
+        Args:
+            node (node)
+            idx (int)
+        """
         self.assign_index_as_input(node, idx)
         self.inherit_computation(node.args[0], node)
         self.mark_computation(node, idx, [-1, -2])
 
     def assign_elementwise_index(self, node, idx):
+        """
+        Assign index for element-wise op (e.g. relu, sigmoid, add, mul).
+        1. assign index as input node
+        2. inherit computation from all input nodes
+
+        Args:
+            node (node)
+            idx (int)
+        """
         self.assign_index_as_input(node, idx)
         for node_in in node.args:
             if type(node_in) not in (int, float):
                 self.inherit_computation(node_in, node)
 
     def assign_softmax_index(self, node, idx):
+        """
+        Assign index for softmax op.
+        1. assign index as input node
+        2. inherit computation and mark the softmax dim as computed
+
+        Args:
+            node (node)
+            idx (int)
+        """
         self.assign_index_as_input(node, idx)
+        self.inherit_computation(node.args[0], node)
         self.mark_computation(node, idx, [node.kwargs['dim']])
 
     def assign_view_reshape_index(self, node, node_idx):
+        """
+        Assign index for view and reshape op.
+        1. get the origin shape and target shape from meta info
+        2. compute the real value of -1 in the target shape
+        3. determine the changed dims, and assign an index for each generated dim
+        4. log the changed and generated dims for restore
+        5. look into the view list to see whether this view is associated with another;
+           if so, assign equal dims according to the previous view
+        6. inherit computation
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
         # get data, turn into number
         origin_node = node.args[0]
         origin_shape = origin_node.meta['tensor_meta'].shape
@@ -305,6 +399,7 @@ class NodeIndexTracer(object):
         else:
             raise NotImplementedError(node.op, "op not implemented yet!")
 
+
 def _get_meta_node_size(x):
     x = x.meta['tensor_meta']
     x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
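
For context, a minimal sketch (not part of the patch) of the per-dim index-trace idea the new docstrings describe: each tensor dimension gets an integer id, and an op like transpose only permutes those ids, which is what assign_transpose_index does with the input's trace before inheriting its computation. The helper name transpose_trace below is hypothetical and only for illustration; it is not an API of chunk_codegen.py.

    # Illustrative sketch only; transpose_trace is a hypothetical helper.
    def transpose_trace(trace, dim0, dim1):
        """Return a copy of `trace` with the ids of dim0 and dim1 swapped."""
        new_trace = list(trace)
        new_trace[dim0], new_trace[dim1] = new_trace[dim1], new_trace[dim0]
        return new_trace

    input_trace = [0, 1, 2, 3]                 # one id per dim of a (B, H, S, D) tensor
    print(transpose_trace(input_trace, 1, 2))  # [0, 2, 1, 3] -- dims 1 and 2 swapped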