From d361d533e8e7773d2009cc4ff5a82633401ab44a Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 21 Dec 2022 15:01:03 +0800 Subject: [PATCH] refactor flow tracer --- chunk_codegen.py | 281 +++++++++++++++++++++++++++++++++-------- evoformer/evoformer.py | 11 +- 2 files changed, 239 insertions(+), 53 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 2c1c09ae5..3ba082ceb 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -139,7 +139,13 @@ class IndexTracer(object): node_from_idx = _find_idx_by_name(node_from.name, self.nodes_list) if init: node_to_trace["source"][node_to_dim] = {} - node_to_trace["source"][node_to_dim][node_from_idx] = node_from_dim + # add dim to cur new source + if node_from_idx not in node_to_trace["source"][node_to_dim]: + node_to_trace["source"][node_to_dim][node_from_idx] = [node_from_dim] + else: + if node_from_dim not in node_to_trace["source"][node_to_dim][node_from_idx]: + node_to_trace["source"][node_to_dim][node_from_idx].append(node_from_dim) + # update inputs source node_to_trace["source"][node_to_dim].update( node_from_trace["source"][node_from_dim] ) @@ -654,7 +660,7 @@ class IndexTracer(object): end_node_trace_source.items(), key=lambda d: d[0], reverse=True ) for node_idx, node_dim in sorted_source: - if node_idx == start_node_idx and node_dim == start_dim: + if node_idx == start_node_idx and start_dim in node_dim: return True # it means we meet a node outside the loop, and the node is not input node if node_idx < start_idx: @@ -694,12 +700,12 @@ class IndexTracer(object): for node_dim in range(len(_get_node_shape(node))): if ( input_node_idx in node_trace_source[node_dim] - and node_trace_source[node_dim][input_node_idx] == input_dim + and input_dim in node_trace_source[node_dim][input_node_idx] ): return node_dim return None - def check_index_duplicate(self, chunk_infos): + def check_index_duplicate(self, chunk_infos, return_dim=False): input_dim_after_node = {} for input_node_idx, input_node in enumerate(chunk_infos["inputs"]): for k, v in chunk_infos["inputs_dim"][input_node_idx].items(): @@ -713,17 +719,30 @@ class IndexTracer(object): if _is_non_compute_node_except_placeholder(node): continue count = 0 + duplicate_dims = [] node_trace_source = self._find_source_trace_from_node(node) for node_dim in range(len(_get_node_shape(node))): + duplicate_dim = [] + duplicate_flag = False dim_source = node_trace_source[node_dim] for k, v in dim_source.items(): if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]: - if k in input_dim_after_node and input_dim_after_node[k] == v: - count += 1 - break + if k in input_dim_after_node and input_dim_after_node[k] in v: + duplicate_flag = True + duplicate_dim.append((k, v)) + duplicate_dims.append(duplicate_dim) + if duplicate_flag: + count += 1 + if count > 1: - return False - return True + if return_dim: + return False, duplicate_dims + else: + return False + if return_dim: + return True, None + else: + return True @@ -857,43 +876,45 @@ class FlowTracer(object): flow_block = True return flow_block, chunk_info - for idx in range(start_idx, end_idx + 1): - node = self.node_list[idx] - mix_flow_node = self._get_flow_mix_node(node) - if mix_flow_node is None: - continue + # for idx in range(start_idx, end_idx + 1): + # node = self.node_list[idx] + # mix_flow_node = self._get_flow_mix_node(node) + # if mix_flow_node is None: + # continue - # if there is a flow mix, op must be in [mul, add, matmul] - # element-wise op requires dim to be equal in every dim - if any(n in node.name for n in ["mul", "add"]): - for i in node.args: - if type(i) == type(mix_flow_node) and i != mix_flow_node: - main_flow_var = i - # if mix flow is a broadcast in chunk dim, - # TODO: need to move that flow out of the chunk - mix_flow_node_dim = index_tracer.get_node_chunk_dim( - self.node_list[end_idx], end_dim, node - ) - if mix_flow_node_dim is None: - flow_block = True - break - if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: - flow_block = False - for i in self._get_same_flow_node( - chunk_info["inputs"], mix_flow_node - ): - chunk_info["inputs"].remove(i) - # else, we need to chunk mix var as well - else: - # TODO chunk another value - flow_block = True - break - else: - raise NotImplementedError("%s not implemented" % node.name) - - if flow_block: - flow_block = True - return flow_block, chunk_info + # # if there is a flow mix, op must be in [mul, add, matmul] + # # element-wise op requires dim to be equal in every dim + # if any(n in node.name for n in ["mul", "add"]): + # for i in node.args: + # if type(i) == type(mix_flow_node) and i != mix_flow_node: + # main_flow_var = i + # # if mix flow is a broadcast in chunk dim, + # # TODO: need to move that flow out of the chunk + # mix_flow_node_dim = index_tracer.get_node_chunk_dim( + # self.node_list[end_idx], end_dim, node + # ) + # # TODO: we need to loop every dim + # if isinstance(mix_flow_node_dim, list): + # mix_flow_node_dim = mix_flow_node_dim[0] + # if mix_flow_node_dim is None: + # flow_block = True + # break + # if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1: + # flow_block = False + # for i in self._get_same_flow_node( + # chunk_info["inputs"], mix_flow_node + # ): + # chunk_info["inputs"].remove(i) + # # else, we need to chunk mix var as well + # else: + # # TODO chunk another value + # flow_block = True + # break + # else: + # raise NotImplementedError("%s not implemented" % node.name) + # if flow_block: + # flow_block = True + # return flow_block, chunk_info inputs_dim = [] remove_inputs = [] @@ -908,6 +929,9 @@ class FlowTracer(object): dim = index_tracer.get_node_chunk_dim( self.node_list[end_idx], end_dim, input_node ) + # TODO: we need to loop every dim + if isinstance(dim, list): + dim = dim[0] elif user_idx == end_idx: dim = end_dim # n has relation with chunk dim @@ -921,6 +945,8 @@ class FlowTracer(object): for i in remove_inputs: if i in chunk_info["inputs"]: chunk_info["inputs"].remove(i) + + duplicate_result, duplicate_dim = index_tracer.check_index_duplicate(chunk_info, return_dim=True) # we need to log input nodes to avoid deleteing them in the loop non_chunk_inputs = _find_chunk_all_input_nodes( @@ -932,6 +958,150 @@ class FlowTracer(object): return flow_block, chunk_info + def _assgin_single_node_flow(self, arg_node, start_idx, end_idx, + inputs, index_tracer, cur_node_dim, + cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info, + next_node_list): + arg_idx = _find_idx_by_name(arg_node.name, index_tracer.nodes_list) + # arg in chunk range or be inputs + if not (start_idx <= arg_idx < end_idx): + return True + + # find arg dim + if cur_node_dim is not None: + # dim is computed + if arg_idx in cur_node_compute[cur_node_dim]: + return False + if arg_idx not in cur_node_source[cur_node_dim]: + arg_dim = None + else: + arg_dim = cur_node_source[cur_node_dim][arg_idx][0] + else: + arg_dim = None + + # get fix dim + arg_fix_dim = [] + if cur_node_dim is not None: + for i in cur_node_fix_dim: + fix_dim_source = cur_node_source[i] + if arg_idx in fix_dim_source: + arg_fix_dim.append(fix_dim_source[arg_idx][0]) + + # if already in node_info, arg dim must be same + if arg_node in all_node_info: + if all_node_info[arg_node] != arg_dim: + return False + all_node_info[arg_node]['fix_dim'] = list(set(all_node_info[arg_node]['fix_dim'] + arg_fix_dim)) + # else add it to list + else: + all_node_info[arg_node] = {'chunk_dim': arg_dim, 'fix_dim': arg_fix_dim} + + next_node_list.append(arg_node) + return True + + def flow_search(self, start_idx, start_dim, end_idx, end_dim, index_tracer: IndexTracer): + inputs, outputs = _find_chunk_compute_input_and_output_nodes( + self.node_list[start_idx : end_idx + 1] + ) + # only single ouput + if len(outputs) > 1: + return None + + cur_node_list = [index_tracer.nodes_list[end_idx]] # start from the last node + all_node_info = {cur_node_list[0]: {'chunk_dim': end_dim, 'fix_dim': []}} + + while len(cur_node_list) > 0: + next_node_list = [] + + for cur_node in cur_node_list: + # get cur node info + cur_node_chunk_dim = all_node_info[cur_node]['chunk_dim'] + cur_node_fix_dim = all_node_info[cur_node]['fix_dim'] + cur_node_idx = _find_idx_by_name(cur_node.name, index_tracer.nodes_list) + if cur_node_chunk_dim: + cur_node_compute = index_tracer._find_compute_trace_from_node(cur_node) + cur_node_source = index_tracer._find_source_trace_from_node(cur_node) + else: + cur_node_compute = cur_node_source = None + + # get all valid args + arg_list = [] + for arg in cur_node.args: + if type(arg) != type(cur_node): + continue + if _is_non_compute_node(arg): + continue + arg_list.append(arg) + flow_flag = self._assgin_single_node_flow(arg, start_idx, end_idx, + inputs, index_tracer, cur_node_chunk_dim, + cur_node_compute, cur_node_source, cur_node_fix_dim, all_node_info, + next_node_list) + if flow_flag == False: + return None + + if len(arg_list) == 2: + if any(i in cur_node.name for i in ["add", "mul"]): + for arg in arg_list: + if not (start_idx <= _find_idx_by_name(arg.name, index_tracer.nodes_list) < end_idx): + continue + arg_chunk_dim = all_node_info[arg]['chunk_dim'] + arg_fix_dim = all_node_info[arg]['fix_dim'] + arg_shape = _get_node_shape(arg) + # add all dim as fix dim except chunk dim + for i, shape in enumerate(arg_shape): + if shape != 1 and i != cur_node_chunk_dim: + if i == arg_chunk_dim: + return None + if i not in arg_fix_dim: + arg_fix_dim.append(i) + elif "einsum" in cur_node.name: + pass + elif "matmul" in cur_node.name: + pass + else: + raise NotImplementedError() + cur_node_list = next_node_list + + inputs_dim = [] + remove_inputs = [] + for input_node in inputs: + input_dict = {} + for user in input_node.users.keys(): + if _is_non_compute_node(user): + continue + user_idx = _find_idx_by_name(user.name, self.node_list) + if start_idx <= user_idx <= end_idx: + chunk_dim = all_node_info[user]['chunk_dim'] + if chunk_dim is not None: + input_dict[user_idx] = chunk_dim + if len(input_dict) == 0: + remove_inputs.append(input_node) + else: + inputs_dim.append(input_dict) + for i in remove_inputs: + if i in inputs: + inputs.remove(i) + + chunk_info = { + "region": (start_idx, end_idx), + "inputs": inputs, + "inputs_non_chunk": [], + "inputs_dim": inputs_dim, + "outputs": outputs, + "outputs_dim": end_dim, + "args": {}, + } + + # we need to log input nodes to avoid deleteing them in the loop + non_chunk_inputs = _find_chunk_all_input_nodes( + self.node_list[start_idx : end_idx + 1] + ) + for i in non_chunk_inputs: + if i not in chunk_info["inputs"]: + chunk_info["inputs_non_chunk"].append(i) + + return chunk_info + class MemoryEstimator(object): def __init__(self, index_tracer: IndexTracer) -> None: @@ -1055,12 +1225,13 @@ class MemoryEstimator(object): node_source = self.index_tracer._find_source_trace_from_node(node) for (input_node, input_node_dim) in zip(chunk_inputs, chunk_inputs_dim): for k, v in input_node_dim.items(): + # TODO: inherit dim should be list too, int now inherit_dim = self.index_tracer._find_inherit_dim(input_node, v, self.index_tracer.nodes_list[k]) if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list): chunk_ratio = float(chunk_size) / node_shape[inherit_dim] return chunk_ratio for dim, source in enumerate(node_source): - if k in source and source[k] == inherit_dim: + if k in source and inherit_dim in source[k]: chunk_ratio = float(chunk_size) / node_shape[dim] return chunk_ratio return 1. @@ -1323,9 +1494,11 @@ class ChunkRegionSearch(object): continue for start_node, start_trace in start_traces.items(): for start_dim, start_trace_idx in enumerate(start_trace["idx"]): - # must be same trace idx - if start_trace_idx != end_trace_idx: - continue + if start_idx == 199 and end_idx == 229 and start_dim == 2 and end_dim == 2: + print(1) + self.flow_tracer.flow_search( + start_idx, start_dim, end_idx, end_dim, self.index_tracer + ) # dim size cannot be 1 if ( _get_node_shape(end_node)[end_dim] == 1 @@ -1343,10 +1516,16 @@ class ChunkRegionSearch(object): ): continue # detect flow meet - flow_block, chunk_info = self.flow_tracer._detect_flow( + # flow_block, chunk_info = self.flow_tracer._detect_flow( + # start_idx, start_dim, end_idx, end_dim, self.index_tracer + # ) + # if flow_block: + # continue + # flow search + chunk_info = self.flow_tracer.flow_search( start_idx, start_dim, end_idx, end_dim, self.index_tracer ) - if flow_block: + if chunk_info is None: continue # check index copmute if not self.index_tracer.check_index_duplicate(chunk_info): diff --git a/evoformer/evoformer.py b/evoformer/evoformer.py index 0c5ab952a..cfd2bb2a2 100644 --- a/evoformer/evoformer.py +++ b/evoformer/evoformer.py @@ -6,6 +6,13 @@ from .ops import OutProductMean from .triangle import PairStack +def print_memory(init_mem, text=None): + now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem + max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem + print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem)) + torch.cuda.reset_peak_memory_stats() + + class EvoformerBlock(nn.Module): def __init__(self, d_node, d_pair): @@ -16,9 +23,9 @@ class EvoformerBlock(nn.Module): self.pair_stack = PairStack(d_pair=d_pair) def forward(self, node, pair): - node = node + self.msa_stack(node, pair) + node = self.msa_stack(node, pair) pair = pair + self.communication(node) - pair = pair + self.pair_stack(pair) + pair = self.pair_stack(pair) return node, pair