diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index de5e7356b..8c3155a60 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -22,7 +22,7 @@ if CODEGEN_AVAILABLE:
     from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
 
 from .search_chunk import SearchChunk
-from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
+from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
 
 
 def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
@@ -276,11 +276,17 @@ if CODEGEN_AVAILABLE:
 
     class AutoChunkCodeGen(CodeGen):
 
-        def __init__(self, meta_graph, max_memory=None, print_mem=False):
+        def __init__(self,
+                     meta_graph,
+                     max_memory: int = None,
+                     print_mem: bool = False,
+                     print_progress: bool = False) -> None:
             super().__init__()
             # find the chunk regions
-            self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
+            self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
             self.chunk_infos = self.search_chunk.search_region()
+            if print_progress:
+                get_logger().info("AutoChunk start codegen")
 
         def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
             free_vars: List[str] = []
diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py
index 21f34481b..a03a5413b 100644
--- a/colossalai/autochunk/estimate_memory.py
+++ b/colossalai/autochunk/estimate_memory.py
@@ -43,6 +43,8 @@ class EstimateMemory(object):
         delete_node = []
         if user.op not in ("output",):
             nodes_to_delete = user_to_last_uses.get(user, [])
+            if len(user.users) == 0:
+                nodes_to_delete.append(user)
             if to_keep is not None:
                 keep_list = []
                 for n in nodes_to_delete:
@@ -135,6 +137,8 @@
         if user.op in ("placeholder", "output"):
             return 0
         nodes_to_delete = user_to_last_uses.get(user, [])
+        if len(user.users) == 0:
+            nodes_to_delete.append(user)
         delete_size = 0
         for n in nodes_to_delete:
             if n.name in chunk_inputs_names:
@@ -294,3 +298,26 @@
         # param_memory = parameter_size(gm)
         # all_memory = act_memory + param_memory
         return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
+
+    def get_active_nodes(self, node_list: List) -> List:
+        """
+        Get active nodes for every node
+
+        Args:
+            node_list (List): node list of the graph
+
+        Returns:
+            active_node_list_log (List): active nodes of every node. active nodes refer to
+                nodes generated but not deleted.
+        """
+        active_node_list = []
+        active_node_list_log = []
+        user_to_last_uses = self._get_last_usr(node_list)
+        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
+        delete_free_var_from_last_use(user_to_last_uses_no_free_var)
+        for _, node in enumerate(node_list):
+            # log active node, only effective without chunk
+            self._add_active_node(node, active_node_list)
+            self._remove_deactive_node(node, user_to_last_uses, active_node_list)
+            active_node_list_log.append(copy.deepcopy(active_node_list))
+        return active_node_list_log
diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index 236f9697d..a86196712 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
 
 
 class SearchChunk(object):
@@ -40,14 +40,14 @@
         print_mem (bool): print estimated memory
     """
 
-    def __init__(self, gm, max_memory=None, print_mem=False) -> None:
-        self.gm = gm
+    def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
         self.print_mem = print_mem
+        self.print_progress = print_progress
         self.trace_indice = TraceIndice(list(gm.graph.nodes))
-        self.trace_indice.trace_indice()
+        self.estimate_memory = EstimateMemory()
+        self._init_trace()
         self.trace_flow = TraceFlow(self.trace_indice)
         self.reorder_graph = ReorderGraph(self.trace_indice)
-        self.estimate_memory = EstimateMemory()
         self.select_chunk = SelectChunk(
             self.trace_indice,
             self.estimate_memory,
@@ -55,7 +55,33 @@
             max_memory=max_memory,
         )
 
-    def _find_peak_node(self, mem_peak):
+    def _init_trace(self) -> None:
+        """
+        find the max trace range for every node
+        reduce the computation complexity of trace_indice
+        """
+        # find all max ranges
+        active_nodes = self.estimate_memory.get_active_nodes(self.trace_indice.node_list)
+        cur_node_idx = len(self._get_free_var_idx())
+        max_chunk_region_list = []
+        while True:
+            max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
+            cur_node_idx = max_chunk_region[1]
+            if cur_node_idx == len(active_nodes) - 1:
+                break
+            max_chunk_region_list.append(max_chunk_region)
+
+        # nothing to limit for the first range
+        max_chunk_region_list = max_chunk_region_list[1:]
+        max_chunk_region_list[0] = (0, max_chunk_region_list[0][1])
+
+        # set trace range and do the trace
+        if self.print_progress:
+            get_logger().info("AutoChunk start tracing indice")
+        self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes)
+        self.trace_indice.trace_indice()
+
+    def _find_peak_node(self, mem_peak: List) -> int:
         max_value = max(mem_peak)
         max_idx = mem_peak.index(max_value)
         return max_idx
@@ -73,7 +99,7 @@
                 free_var_idx.append(idx)
         return free_var_idx
 
-    def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
+    def _search_max_chunk_region(self, active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple:
         """
         Search max chunk region according to peak memory node
 
@@ -81,7 +107,7 @@
 
         Args:
             active_node (List): active node status for every node
-            peak_node (Node): peak memory node
+            peak_node_idx (int): peak memory node idx
             chunk_regions (List): chunk region infos
 
         Returns:
@@ -97,7 +123,7 @@
         # from peak_node to free_var
         inside_flag = False
         chunk_region_start = free_var_num
-        for i in range(peak_node, -1, -1):
+        for i in range(peak_node_idx, -1, -1):
             if active_node_num[i] <= threshold:
                 inside_flag = True
             if inside_flag and active_node_num[i] > threshold:
@@ -107,21 +133,23 @@
         # from peak_node to len-2
         inside_flag = False
         chunk_region_end = len(active_node) - 1
-        for i in range(peak_node, len(active_node)):
+        for i in range(peak_node_idx, len(active_node)):
             if active_node_num[i] <= threshold:
                 inside_flag = True
             if inside_flag and active_node_num[i] > threshold:
                 chunk_region_end = i
                 break
 
-        for i in chunk_regions:
-            region = i["region"]
-            if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
-                return None
-            elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
-                chunk_region_start = region[1] + 1
-            elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
-                chunk_region_end = region[0] - 1
+        # avoid chunk regions overlap
+        if chunk_regions is not None:
+            for i in chunk_regions:
+                region = i["region"]
+                if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
+                    return None
+                elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
+                    chunk_region_start = region[1] + 1
+                elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
+                    chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end
 
     def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
@@ -154,6 +182,9 @@
             # dim size cannot be 1
             if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
                 continue
+            # must have users
+            if len(end_node.users) == 0:
+                continue
             # check index source align
             if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
                 continue
@@ -253,6 +284,9 @@
         Returns:
             chunk_infos (Dict)
         """
+        if self.print_progress:
+            get_logger().info("AutoChunk start searching chunk regions")
+
         chunk_infos = []
         (
             init_mem_peak,
@@ -272,6 +306,11 @@
                 _,
                 active_node,
             ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
+
+            if self.print_progress:
+                get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
+                                  (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
+
             if self._stop_search(init_mem_peak, mem_peak):
                 break
         if self.print_mem:
diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py
index e657c188e..830b4629e 100644
--- a/colossalai/autochunk/trace_flow.py
+++ b/colossalai/autochunk/trace_flow.py
@@ -281,7 +281,10 @@
            if chunk_dim is not None:
                 user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
                 if input_node_idx in user_source:
-                    input_dict[user_idx] = user_source[input_node_idx]
+                    if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
+                        input_dict[user_idx] = [None]
+                    else:
+                        input_dict[user_idx] = user_source[input_node_idx]
                 else:
                     return None, None
         if len(input_dict) == 0:
diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py
index 5c2e9b520..827f60d8b 100644
--- a/colossalai/autochunk/trace_indice.py
+++ b/colossalai/autochunk/trace_indice.py
@@ -33,6 +33,8 @@ class TraceIndice(object):
         self.indice_trace_list = self._init_indice_trace_list()
         self.indice_view_list = {}
         self.indice_count = -1
+        self.trace_range = []
+        self.active_node_list = []
 
     def _init_indice_trace_list(self):
         indice_trace_list = []
@@ -48,6 +50,10 @@
             indice_trace_list.append(cur_trace)
         return indice_trace_list
 
+    def set_trace_range(self, trace_range: List, active_node_list: List) -> None:
+        self.trace_range = trace_range
+        self.active_node_list = active_node_list
+
     def _add_indice(self):
         """
         Update the count and return it. To record the idx number.
         """
@@ -493,6 +499,9 @@
         new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
         for _ in range(new_dim_num):
             self._del_dim(node_idx, 0)
+        delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
+        for _ in range(delete_dim_num):
+            self._add_dim(node_idx, 0)
         self._assign_indice_as_input(node, node_idx)
 
         for _, node_arg in enumerate(node_args):
@@ -513,6 +522,9 @@
             elif "None" == node_arg_str:
                 self._add_dim(node_idx, new_idx_count)
                 new_idx_count += 1
+            elif "0" == node_arg_str:
+                self._del_dim(node_idx, new_idx_count)
+                origin_idx_count += 1
             else:
                 raise NotImplementedError()
 
@@ -596,6 +608,37 @@
         }
         self.indice_view_list[node] = view_dict
 
+    def _clear_trace(self, node_idx: int) -> None:
+        """
+        clear too far trace to speed up computation
+        """
+        trace_range = None
+        for i in range(len(self.trace_range)):
+            if self.trace_range[i][1] == node_idx:
+                trace_range = (self.trace_range[i][0], self.trace_range[i][1])
+                break
+            if self.trace_range[i][1] > node_idx:
+                break
+        if trace_range is None:
+            return
+
+        active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
+        active_nodes = set(flat_list(active_nodes))
+        active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
+        for i in range(trace_range[0], trace_range[1] + 1):
+            trace = self.indice_trace_list[i]
+            # clear compute
+            for dim_compute in trace["compute"]:
+                for i in range(len(dim_compute) - 1, -1, -1):
+                    if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes:
+                        dim_compute.pop(i)
+                        continue
+            # clear source
+            for dim_source in trace["source"]:
+                for k in list(dim_source.keys()):
+                    if k < trace_range[0] and k not in active_nodes:
+                        dim_source.pop(k)
+
     def trace_indice(self):
         for idx, node in enumerate(self.node_list):
             if node.op == "placeholder":
@@ -655,3 +698,6 @@
                 continue
             else:
                 raise NotImplementedError(node.op, "op not implemented yet!")
+
+            # limit trace range
+            self._clear_trace(idx)
diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py
index ff1a64bc3..e87068512 100644
--- a/colossalai/autochunk/utils.py
+++ b/colossalai/autochunk/utils.py
@@ -2,6 +2,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
 
 from torch.fx.node import Node
 
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+def get_logger():
+    return logger
+
 
 def flat_list(inputs: Any) -> List:
     """
diff --git a/tests/test_autochunk/test_evoformer_stack_codegen.py b/tests/test_autochunk/test_evoformer_stack_codegen.py
new file mode 100644
index 000000000..5fabb2702
--- /dev/null
+++ b/tests/test_autochunk/test_evoformer_stack_codegen.py
@@ -0,0 +1,163 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.fx
+import torch.multiprocessing as mp
+
+try:
+    from fastfold.model.nn.evoformer import EvoformerStack
+    HAS_REPO = True
+except:
+    HAS_REPO = False
+
+import colossalai
+from colossalai.core import global_context as gpc
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
+    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
+    # for memory test
+    # model = model.cuda()
+    # torch.cuda.reset_peak_memory_stats()
+    # now_mem = torch.cuda.memory_allocated() / 1024**2
+    # with torch.no_grad():
+    #     node1 = node.clone()
+    #     pair1 = pair.clone()
+    #     node_mask1 = node_mask.clone()
+    #     pair_mask1 = pair_mask.clone()
+    #     gm(node1, pair1, node_mask1, pair_mask1, None)
+    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
+
+    # test forward
+    model = model.cuda()
+    with torch.no_grad():
+        non_fx_out = model(node, pair, node_mask, pair_mask, None)
+        fx_out = gm(node, pair, node_mask, pair_mask, None)
+
+    assert torch.allclose(non_fx_out[0], fx_out[0],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[0] - fx_out[0]))
+    assert torch.allclose(non_fx_out[1], fx_out[1],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[1] - fx_out[1]))
+
+
+def _build_openfold():
+    model = EvoformerStack(
+        c_m=256,
+        c_z=128,
+        c_hidden_msa_att=32,
+        c_hidden_opm=32,
+        c_hidden_mul=128,
+        c_hidden_pair_att=32,
+        c_s=384,
+        no_heads_msa=8,
+        no_heads_pair=4,
+        no_blocks=2,    # 48
+        transition_n=4,
+        msa_dropout=0.15,
+        pair_dropout=0.25,
+        blocks_per_ckpt=None,
+        inf=1000000000.0,
+        eps=1e-08,
+        clear_cache_between_blocks=False,
+        is_multimer=False,
+    ).eval().cuda()
+    return model
+
+
+def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory):
+    # launch colossalai
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=1,
+        host="localhost",
+        port=free_port(),
+        backend="nccl",
+    )
+
+    # build model and input
+    model = _build_openfold()
+    node = torch.randn(1, msa_len, pair_len, 256).cuda()
+    node_mask = torch.randn(1, msa_len, pair_len).cuda()
+    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
+    pair_mask = torch.randn(1, pair_len, pair_len).cuda()
+
+    # trace the meta graph and setup codegen
+    meta_graph = symbolic_trace(
+        model,
+        meta_args={
+            "m": node.to(torch.device("meta")),
+            "z": pair.to(torch.device("meta")),
+            "msa_mask": node_mask.to(torch.device("meta")),
+            "pair_mask": pair_mask.to(torch.device("meta")),
+        },
+        concrete_args={
+            "chunk_size": None,
+            "_mask_trans": True,
+        },
+    )
+    interp = MetaInfoProp(meta_graph)
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"),
+                     MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None)
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False)
+
+    # trace and recompile
+    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+    graph = ColoTracer().trace(
+        model,
+        meta_args={
+            "m": node.to(torch.device("meta")),
+            "z": pair.to(torch.device("meta")),
+            "msa_mask": node_mask.to(torch.device("meta")),
+            "pair_mask": pair_mask.to(torch.device("meta")),
+        },
+        concrete_args={
+            "chunk_size": None,
+            "_mask_trans": True,
+        },
+    )
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+    gm.recompile()
+
+    # assert we have inserted chunk
+    code = graph.python_code("self").src
+    # print(code)
+    assert "chunk_result = None; chunk_size = None;" in code
+
+    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
+    gpc.destroy()
+
+
+@pytest.mark.skipif(
+    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
+    reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
+@pytest.mark.parametrize("msa_len", [32])
+@pytest.mark.parametrize("pair_len", [64])
+def test_evoformer_stack_codegen(msa_len, pair_len, max_memory):
+    run_func = partial(
+        _test_evoformer_stack_codegen,
+        msa_len=msa_len,
+        pair_len=pair_len,
+        max_memory=max_memory,
+    )
+    mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+    _test_evoformer_stack_codegen(0, 32, 64, None)