diff --git a/chunk_codegen.py b/chunk_codegen.py index 1e8305ba3..ce7d84917 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -16,6 +16,11 @@ def _delete_free_var_from_last_use(user_to_last_uses): if n.op == 'placeholder': user_to_last_uses[key].remove(n) +def _get_node_shape(node): + if hasattr(node.meta['tensor_meta'], "shape"): + return node.meta['tensor_meta'].shape + return None + class FlowTracer(object): def __init__(self, gm) -> None: @@ -136,11 +141,25 @@ class IndexTracer(object): def __init__(self, gm) -> None: self.gm = gm self.nodes_list = list(gm.graph.nodes) - self.idx_trace_list = [{'idx': [], 'compute': {}} for _ in range(len(self.nodes_list))] + self.idx_trace_list = self._init_idx_trace_list() self.idx_trace_equal = [] self.idx_view_list = [] self.idx_count = -1 + def _init_idx_trace_list(self): + idx_trace_list = [] + for n in self.nodes_list: + if _get_node_shape(n) != None: + cur_trace = { + 'idx': [None for _ in range(len(_get_node_shape(n)))], + 'compute': [[] for _ in range(len(_get_node_shape(n)))], + 'source': [[] for _ in range(len(_get_node_shape(n)))], + } + else: + cur_trace = {'idx': [], 'compute': [], 'source': []} + idx_trace_list.append(cur_trace) + return idx_trace_list + def _add_index(self): """ Update the count and return it. To record the idx number. @@ -150,35 +169,81 @@ class IndexTracer(object): """ self.idx_count += 1 return self.idx_count - - def _inherit_computation(self, node_from, node_to): - """ - Inherit computed dim from node_from to node_to. - If a dim in node_from is marked as computed and exists in node_to, - still mark it as computed in node_to. - - Args: - node_from (node): node to be inherited - node_to (node): new node to inherit - """ - _, compute_from = self._find_trace_from_node(node_from) - idx_to, compute_to = self._find_trace_from_node(node_to) - for k, v in compute_from.items(): - if k in idx_to: - if k in compute_to: - compute_to[k].extend(v) - else: - compute_to[k] = copy.deepcopy(v) - def _mark_idx_equal(self, idx1, idx2): + def _del_dim(self, idx, dim_idx): + self.idx_trace_list[idx]['idx'].pop(dim_idx) + self.idx_trace_list[idx]['compute'].pop(dim_idx) + self.idx_trace_list[idx]['source'].pop(dim_idx) + + def _add_dim(self, idx, dim_idx): + self.idx_trace_list[idx]['idx'].insert(dim_idx, self._add_index()) + self.idx_trace_list[idx]['compute'].insert(dim_idx, []) + self.idx_trace_list[idx]['source'].insert(dim_idx, []) + + def _transform_index(self, node, node_dim): + node_idx = self._find_idx_trace_from_node(node) + dims = list(range(len(node_idx))) + return dims[node_dim] + + def _inherit_index(self, node_from, node_from_dim, node_to, node_to_dim): + node_from_dim = self._transform_index(node_from, node_from_dim) + node_to_dim = self._transform_index(node_to, node_to_dim) + node_from_trace = self._find_trace_from_node(node_from) + node_to_trace = self._find_trace_from_node(node_to) + node_to_trace['idx'][node_to_dim] = node_from_trace['idx'][node_from_dim] + node_to_trace['compute'][node_to_dim] = copy.deepcopy(node_from_trace['compute'][node_from_dim]) + node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) + node_to_trace['source'][node_to_dim] = [] + node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim}) + node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim]) + + def _inherit_all_computation(self, node_from, node_to): + node_from_compute = self._find_compute_trace_from_node(node_from) + node_to_compute = self._find_compute_trace_from_node(node_to) + assert len(node_from_compute) == len(node_to_compute) + for i in range(len(node_from_compute)): + self._add_source(node_from, i, node_to, i) + node_to_compute[i] = copy.deepcopy(node_from_compute[i]) + + def _add_source(self, node_from, node_from_dim, node_to, node_to_dim): + node_from_dim = self._transform_index(node_from, node_from_dim) + node_from_trace = self._find_trace_from_node(node_from) + node_to_dim = self._transform_index(node_to, node_to_dim) + node_to_trace = self._find_trace_from_node(node_to) + node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) + node_to_trace['source'][node_to_dim].append({node_from_idx: node_from_dim}) + node_to_trace['source'][node_to_dim].extend(node_from_trace['source'][node_from_dim]) + + def _mark_computation_from_node(self, node_from, node_to, exclude=None): + if exclude == None: + exclude = [] + else: + exclude = [self._transform_index(node_to, i) for i in exclude] + node_from_compute = self._find_compute_trace_from_node(node_from) + node_to_compute = self._find_compute_trace_from_node(node_to) + # assert len(node_from_compute) == len(node_to_compute) + for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1): + if self._transform_index(node_to, i) in exclude: + continue + self._add_source(node_from, i, node_to, i) + for j in node_from_compute[i]: + if j not in node_to_compute[i]: + node_to_compute[i].append(j) + + def _mark_idx_equal(self, node1, dim1, node2, dim2): """ Mark 2 index to be equal. Args: idx1 (int): index count. idx2 (int): index count. - """ - self.idx_trace_equal.append((idx1, idx2)) + """ + # node1_idx = _find_idx_by_name(node1.name, self.nodes_list) + # node2_idx = _find_idx_by_name(node2.name, self.nodes_list) + # if node1_idx > node2_idx: + # self._add_source(node2, dim2, node1, dim1) + # else: + # self._add_source(node1, dim1, node2, dim2) def _mark_computation(self, node, idx, dim): """ @@ -189,16 +254,14 @@ class IndexTracer(object): idx (int): node index dim (list or int): dims to be marked as computed """ - input_node_idx_trace = self._find_idx_trace_from_node(node) if isinstance(dim, int): dim = [dim] + dims = list(range(len(_get_node_shape(node)))) for d in dim: - cur_idx = input_node_idx_trace[d] - if cur_idx not in self.idx_trace_list[idx]['compute']: - self.idx_trace_list[idx]['compute'][cur_idx] = [idx] - else: - self.idx_trace_list[idx]['compute'][cur_idx].append(idx) - + cur_dim = dims[d] + if idx not in self.idx_trace_list[idx]['compute'][cur_dim]: + self.idx_trace_list[idx]['compute'][cur_dim].append(idx) + def _find_trace_from_node(self, node): """ Find node idx and compute trace by the node. @@ -211,7 +274,7 @@ class IndexTracer(object): """ node_idx = _find_idx_by_name(node.name, self.nodes_list) node_dict = self.idx_trace_list[node_idx] - return node_dict['idx'], node_dict['compute'] + return node_dict def _find_idx_trace_from_node(self, node): """ @@ -237,19 +300,23 @@ class IndexTracer(object): node_idx = _find_idx_by_name(node.name, self.nodes_list) return self.idx_trace_list[node_idx]['compute'] - def _assign_index_as_input(self, node, node_idx): + def _assign_index_as_input(self, node, node_idx, input_node=None): """ Assign node's trace as its input node. Args: node (node) node_idx (int) - """ - input_node_idx = _find_idx_by_name(node.args[0].name, self.nodes_list) + """ + if input_node == None: + input_node = node.args[0] + input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list) input_node_idx_trace = self.idx_trace_list[input_node_idx]['idx'] new_idx_trace = copy.deepcopy(input_node_idx_trace) self.idx_trace_list[node_idx]['idx'] = new_idx_trace + + self._inherit_all_computation(input_node, node) def _assign_all_index(self, node, node_idx): """ @@ -275,15 +342,12 @@ class IndexTracer(object): node (node) node_idx (int) """ + input_node = node.args[0] tranpose_dim = node.args[1:] - input_node_idx_trace = self._find_idx_trace_from_node(node.args[0]) - new_idx_trace = copy.deepcopy(input_node_idx_trace) - new_idx_trace[tranpose_dim[0]] = input_node_idx_trace[tranpose_dim[1]] - new_idx_trace[tranpose_dim[1]] = input_node_idx_trace[tranpose_dim[0]] - - self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self._inherit_computation(node.args[0], node) + self._assign_index_as_input(node, node_idx, input_node) + self._inherit_index(input_node, tranpose_dim[1], node, tranpose_dim[0]) + self._inherit_index(input_node, tranpose_dim[0], node, tranpose_dim[1]) def _assign_permute_index(self, node, node_idx): """ @@ -296,14 +360,11 @@ class IndexTracer(object): node_idx (int) """ permute_dim = node.args[1:] - input_node_idx_trace = self._find_idx_trace_from_node(node.args[0]) + input_node = node.args[0] - new_idx_trace = copy.deepcopy(input_node_idx_trace) + self._assign_index_as_input(node, node_idx, input_node) for idx, d in enumerate(permute_dim): - new_idx_trace[idx] = input_node_idx_trace[d] - - self.idx_trace_list[node_idx]['idx'] = new_idx_trace - self._inherit_computation(node.args[0], node) + self._inherit_index(input_node, d, node, idx) def _assign_linear_index(self, node, node_idx): """ @@ -321,20 +382,15 @@ class IndexTracer(object): bias = None else: input_node, weight, bias = node.args - input_node_idx_trace = self._find_idx_trace_from_node(input_node) - weight_idx_trace = self._find_idx_trace_from_node(weight) - new_idx_trace = copy.deepcopy(input_node_idx_trace) - new_idx_trace[-1] = weight_idx_trace[1] - self.idx_trace_list[node_idx]['idx'] = new_idx_trace + self._assign_index_as_input(node, node_idx) + self._inherit_index(weight, 1, node, -1) - self._inherit_computation(input_node, node) self._mark_computation(node, node_idx, [-1]) - self._mark_idx_equal(input_node_idx_trace[-1], weight_idx_trace[0]) + self._mark_idx_equal(input_node, -1, weight, 0) if bias: - bias_idx_trace = self._find_idx_trace_from_node(bias) - self._mark_idx_equal(input_node_idx_trace[-1], bias_idx_trace[0]) + self._mark_idx_equal(input_node, -1, bias, 0) def _assign_matmul_index(self, node, node_idx): """ @@ -348,18 +404,14 @@ class IndexTracer(object): node_idx (int) """ matmul_left, matmul_right = node.args - matmul_left_idx_trace = self._find_idx_trace_from_node(matmul_left) - matmul_right_idx_trace = self._find_idx_trace_from_node(matmul_right) - assert(len(matmul_left_idx_trace) == len(matmul_right_idx_trace)) - new_idx_trace = copy.deepcopy(matmul_left_idx_trace) - new_idx_trace[-1] = matmul_right_idx_trace[-1] - self.idx_trace_list[node_idx]['idx'] = new_idx_trace + assert(len(_get_node_shape(matmul_left)) == len(_get_node_shape(matmul_right))) + self._assign_index_as_input(node, node_idx, matmul_left) + self._inherit_index(matmul_right, -1, node, -1) - self._inherit_computation(matmul_left, node) - self._inherit_computation(matmul_right, node) + self._mark_computation_from_node(matmul_right, node, [-1, -2]) self._mark_computation(node, node_idx, [-1]) - self._mark_idx_equal(matmul_left_idx_trace[-1], matmul_right_idx_trace[-2]) + self._mark_idx_equal(matmul_left, -1, matmul_right, -2) def _assign_layernorm_index(self, node, idx): """ @@ -372,7 +424,6 @@ class IndexTracer(object): node_idx (int) """ self._assign_index_as_input(node, idx) - self._inherit_computation(node.args[0], node) self._mark_computation(node, idx, [-1, -2]) def _assign_elementwise_index(self, node, idx): @@ -386,9 +437,59 @@ class IndexTracer(object): node_idx (int) """ self._assign_index_as_input(node, idx) + nodes_in = [] for node_in in node.args: - if type(node_in) not in (int, float): - self._inherit_computation(node_in, node) + if type(node_in) == type(node): + nodes_in.append(node_in) + self._mark_computation_from_node(node_in, node) + assert len(nodes_in) <= 2 + if len(nodes_in) == 2: + node_in0_shape = _get_node_shape(nodes_in[0]) + node_in1_shape = _get_node_shape(nodes_in[1]) + for i in range(-1, -min(len(node_in0_shape), len(node_in1_shape)) - 1, -1): + if node_in0_shape[i] == node_in1_shape[i]: + self._mark_idx_equal(nodes_in[0], i, nodes_in[1], i) + + def _assgin_no_change_index(self, node, idx): + self._assign_index_as_input(node, idx) + for node_in in node.args: + if type(node_in) == type(node): + self._mark_computation_from_node(node_in, node) + + def _assign_einsum_index(self, node, idx): + """ + Assign index for einsum op. + + Args: + node (node) + node_idx (int) + """ + patterns = node.args[0] + input_nodes = node.args[1:] + + patterns = patterns.replace(" ", "") + left, right = patterns.split("->") + left = left.split(",") + + all_index = [] + for i in left: + for c in i: + all_index.append(c) + all_index = set(all_index) + free_index = set([i for i in right]) + sum_index = all_index - free_index + + for right_idx, right_indice in enumerate(right): + for left_idx, left_str in enumerate(left): + if right_indice in left_str: + source_idx = left_str.index(right_indice) + self._inherit_index(input_nodes[left_idx], source_idx, node, right_idx) + + for i in sum_index: + for left_idx, left_str in enumerate(left): + if i in left_str: + self._mark_computation(node, idx, left_str.index(i)) + break def _assign_softmax_index(self, node, idx): """ @@ -401,7 +502,6 @@ class IndexTracer(object): node_idx (int) """ self._assign_index_as_input(node, idx) - self._inherit_computation(node.args[0], node) self._mark_computation(node, idx, [node.kwargs['dim']]) def _assign_unsqueeze_index(self, node, node_idx): @@ -412,10 +512,12 @@ class IndexTracer(object): Args: node (node) node_idx (int) - """ + """ + self._del_dim(node_idx, -1) self._assign_index_as_input(node, node_idx) - self._inherit_computation(node.args[0], node) self.idx_trace_list[node_idx]['idx'].insert(node.args[1], self._add_index()) + self.idx_trace_list[node_idx]['compute'].insert(node.args[1], []) + self.idx_trace_list[node_idx]['source'].insert(node.args[1], []) def _assign_dropout_index(self, node, node_idx): """ @@ -427,7 +529,6 @@ class IndexTracer(object): node_idx (int) """ self._assign_index_as_input(node, node_idx) - def _assign_ones_like_index(self, node, node_idx): """ @@ -439,17 +540,6 @@ class IndexTracer(object): node_idx (int) """ self._assign_all_index(node, node_idx) - - def _assign_to_index(self, node, node_idx): - """ - Assign index for to op. - 1. assign new index for all dim - - Args: - node (node) - node_idx (int) - """ - self._assign_index_as_input(node, node_idx) def _assign_view_reshape_index(self, node, node_idx): """ @@ -494,26 +584,26 @@ class IndexTracer(object): dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)] dim_to = [dim_equal.index(False)] dim_from = [dim_equal.index(False), dim_equal.index(False) + 1] + self._add_dim(node_idx, -1) elif len_diff == -1: # dim expand dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])] dim_from = [dim_equal.index(False)] dim_to = [dim_equal.index(False), dim_equal.index(False) + 1] + self._del_dim(node_idx, -1) else: raise NotImplementedError("shape" + str(origin_shape) + 'and' + str(target_shape) + "view not implemented") # get new index origin_trace = self._find_idx_trace_from_node(origin_node) - new_trace = copy.deepcopy(origin_trace) + self._assign_index_as_input(node, node_idx, origin_node) dim_from.reverse() for i in dim_from: - new_trace.pop(i) + self._del_dim(node_idx, i) for i in dim_to: - new_trace.insert(i, self._add_index()) - self.idx_trace_list[node_idx]['idx'] = new_trace + self._add_dim(node_idx, i) # inherit computation - self._inherit_computation(origin_node, node) compute_log = self._find_compute_trace_from_node(origin_node) for i in dim_from: if origin_trace[i] in compute_log: @@ -524,15 +614,10 @@ class IndexTracer(object): # log view, not used now view_dict = {"idx_from": [origin_trace[i] for i in dim_from], "dim_from": dim_from, - "idx_to": [new_trace[i] for i in dim_to], + "idx_to": [self.idx_trace_list[node_idx]['idx'][i] for i in dim_to], "dim_to": dim_to} self.idx_view_list.append(view_dict) - - def _remove_duplicate_compute(self): - for i in self.idx_trace_list: - for k, v in i['compute'].items(): - i['compute'][k] = list(set(v)) - + def _merge_equal_idx(self): idx_equal = copy.deepcopy(self.idx_trace_equal) idx_equal.reverse() @@ -556,8 +641,8 @@ class IndexTracer(object): self._assign_view_reshape_index(node, idx) elif 'unsqueeze' in node.name: self._assign_unsqueeze_index(node, idx) - elif 'to' in node.name: - self._assign_to_index(node, idx) + elif any(i in node.name for i in ['to', 'contiguous']): + self._assgin_no_change_index(node, idx) else: raise NotImplementedError(node.name, "method not implemented yet!") elif node.op == 'call_function': @@ -573,6 +658,8 @@ class IndexTracer(object): self._assign_ones_like_index(node, idx) elif 'dropout' in node.name: self._assign_dropout_index(node, idx) + elif 'einsum' in node.name: + self._assign_einsum_index(node, idx) elif 'getattr' in node.name: continue # get attr like shape elif 'getitem' in node.name: @@ -590,10 +677,20 @@ class IndexTracer(object): continue else: raise NotImplementedError(node.op, "op not implemented yet!") - - self._remove_duplicate_compute() - self._merge_equal_idx() - + # self._merge_equal_idx() + + def check_index(self, trace_idx, start_idx, end_idx): + for i in range(start_idx, end_idx + 1): + cur_idx = self.idx_trace_list[i]['idx'] + cur_compute = self.idx_trace_list[i]['compute'] + if trace_idx in cur_compute: + for j in cur_compute[trace_idx]: + if j < start_idx or j > end_idx: + return False + # same_idx = [1 if j == trace_idx else 0 for j in cur_idx] + # if sum(same_idx) > 1: + # return False + return True class MemoryEstimator(object): def __init__(self) -> None: @@ -897,6 +994,8 @@ class ChunkRegionSearch(object): self._is_not_compute(after_trace, (start_idx, end_idx), i) and self.node_list[end_idx].meta['tensor_meta'].shape[i] != 1): continue + if not self.index_tracer.check_index(before_trace['idx'][i], start_idx, end_idx): + continue flow_flag, chunk_info = self._detect_flow(before_trace, after_trace, start_idx, end_idx, i) if flow_flag == None: continue @@ -910,7 +1009,10 @@ class ChunkRegionSearch(object): input_trace = [] for i, n in enumerate(self.node_list): if len(n.args) > 0 and n.op != 'output': - input_idx = _find_idx_by_name(n.args[0].name, self.node_list) + if isinstance(n.args[0], str): + input_idx = _find_idx_by_name(n.args[1].name, self.node_list) + else: + input_idx = _find_idx_by_name(n.args[0].name, self.node_list) input_trace.append(output_trace[input_idx]) else: input_trace.append(None) @@ -930,7 +1032,7 @@ class ChunkRegionSearch(object): if len(free_dim) > 0: free_dim = [free_dim[0]] chunk_info = [chunk_info[0]] - possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info}) + possible_chunk_region.append({'region': (start_idx, end_idx), 'dim': free_dim, 'chunk_info': chunk_info}) return possible_chunk_region def _search_best_chunk_region(self, possible_chunk_regions): @@ -1130,6 +1232,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v if node_idx in chunk_starts: within_chunk_region = True + region_idx = chunk_starts.index(node_idx) # add for loop chunk_input_meta = [meta_nodes[i] for i in chunk_inputs_idx[region_idx]] @@ -1150,7 +1253,6 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v if node_idx in chunk_ends: body.append(_gen_loop_end(node, chunk_inputs[region_idx], node_list, chunk_dims[region_idx])) within_chunk_region = False - region_idx += 1 node_idx += 1