From 1d7ca02301c9ff71953070ea963b8e107fa4ccb6 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Thu, 29 Dec 2022 14:28:38 +0800
Subject: [PATCH] add benchmark

---
 autochunk_benchmark.py | 79 ++++++++++++++++++++++++++++++++++++++++++
 chunk_codegen.py       | 16 +++++----
 2 files changed, 89 insertions(+), 6 deletions(-)
 create mode 100644 autochunk_benchmark.py

diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py
new file mode 100644
index 000000000..a34464212
--- /dev/null
+++ b/autochunk_benchmark.py
@@ -0,0 +1,79 @@
+import copy
+import torch
+import torch.nn.functional as F
+import pytest
+import torch.fx
+import torch.multiprocessing as mp
+from torch.fx import GraphModule
+from colossalai.fx import ColoTracer
+import colossalai
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
+from colossalai.fx.profiler import MetaTensor
+from evoformer.evoformer import evoformer_base
+from chunk_codegen import ChunkCodeGen
+import time
+
+
+def _benchmark_evoformer(model: torch.nn.Module, node, pair):
+    loop = 10
+    with torch.no_grad():
+        for _ in range(loop // 4):
+            model(node, pair)
+        torch.cuda.synchronize()
+        time1 = time.time()
+        for _ in range(loop):
+            model(node, pair)
+        torch.cuda.synchronize()
+        time2 = time.time()
+    return (time2 - time1) / loop
+
+
+def benchmark_evoformer():
+    # data
+    msa_len = 300
+    pair_len = 800
+    node = torch.randn(1, msa_len, pair_len, 256).cuda()
+    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
+
+    # build gm model
+    max_memory = 3000  # MB
+    model = evoformer_base().cuda()
+    # trace the module and replace codegen
+    graph = ColoTracer().trace(
+        model,
+        meta_args={
+            "node": node.to(torch.device("meta")),
+            "pair": pair.to(torch.device("meta")),
+        },
+    )
+    gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
+    interp = MetaInfoProp(gm_prop)
+    interp.propagate(
+        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
+    )
+    # now run it twice to get meta info in graph module, not necessary
+    gm = torch.fx.GraphModule(model, graph)
+    interp = MetaInfoProp(gm)
+    interp.propagate(
+        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
+    )
+    # set code_gen
+    codegen = ChunkCodeGen(gm_prop, max_memory)
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph)
+    gm.recompile()
+    # print
+    code = graph.python_code("self").src
+    print(code)
+
+    time_gm = _benchmark_evoformer(gm, node, pair)
+    print("gm %.4fs" % time_gm)
+    time_openfold = _benchmark_evoformer(model, node, pair)
+    print("openfold %.4fs" % time_openfold)
+
+
+if __name__ == "__main__":
+    benchmark_evoformer()
diff --git a/chunk_codegen.py b/chunk_codegen.py
index 6caed88d8..033db50db 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -1398,13 +1398,14 @@ class MemoryEstimator(object):
 
 class ChunkSelector(object):
     def __init__(
-        self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge
+        self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge, max_memory=None
     ):
         self.index_tracer = index_tracer
         self.memory_estimator = memory_estimator
         assert stratge in ["min_memory", "fit_memory"]
+        assert (stratge == "fit_memory" and max_memory is not None) or stratge != "fit_memory"
         self.stratge = stratge
-        self.max_memory = 600  # MB
+        self.max_memory = max_memory  # MB
 
     def _select_best_chunk_region(
         self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
@@ -1556,13 +1557,13 @@ class ChunkSelector(object):
 
 
 class ChunkRegionSearch(object):
-    def __init__(self, gm) -> None:
+    def __init__(self, gm, max_memory=None) -> None:
         self.gm = gm
         self.index_tracer = IndexTracer(list(gm.graph.nodes))
         self.index_tracer.trace_index()
         self.memory_estimator = MemoryEstimator(self.index_tracer)
         self.chunk_selector = ChunkSelector(
-            self.index_tracer, self.memory_estimator, stratge="fit_memory"
+            self.index_tracer, self.memory_estimator, stratge="fit_memory", max_memory=max_memory
         )
 
     def _find_peak_node(self, mem_peak):
@@ -1897,6 +1898,7 @@ def emit_code_with_chunk(
     delete_unused_value_func,
     meta_nodes,
     meta_graph,
+    max_memory=None,
 ):
     """Emit code with nested activation checkpoint
     When we detect some of the node.activation_checkpoint is a List, we will use
@@ -1912,7 +1914,7 @@ def emit_code_with_chunk(
     node_list = list(nodes)
 
     # find the chunk regions
-    chunk_region_search = ChunkRegionSearch(meta_graph)
+    chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
     chunk_search = chunk_region_search.search_region()
     chunk_regions = [i["region"] for i in chunk_search]
 
@@ -1989,9 +1991,10 @@ if CODEGEN_AVAILABLE:
 
 
     class ChunkCodeGen(CodeGen):
-        def __init__(self, meta_graph):
+        def __init__(self, meta_graph, max_memory=None):
             super().__init__()
             self.meta_graph = meta_graph
+            self.max_memory = max_memory
             self.meta_node = list(meta_graph.graph.nodes)
 
         def _gen_python_code(
@@ -2230,6 +2233,7 @@ if CODEGEN_AVAILABLE:
                 delete_unused_values,
                 self.meta_node,
                 self.meta_graph,
+                self.max_memory
             )
 
             if len(body) == 0:
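Usage sketch (illustrative, condensed from autochunk_benchmark.py above, not additional patch content): the new max_memory argument is consumed by ChunkCodeGen, forwarded through ChunkRegionSearch to ChunkSelector's "fit_memory" strategy, and replaces the previously hard-coded 600 MB budget. The snippet assumes the same evoformer model, inputs, and ColossalAI FX utilities that the benchmark script imports, and a 3000 MB budget; it introduces no API beyond what the patch itself adds.

    import torch
    from colossalai.fx import ColoTracer
    from colossalai.fx.graph_module import ColoGraphModule
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp
    from colossalai.fx.profiler import MetaTensor
    from evoformer.evoformer import evoformer_base
    from chunk_codegen import ChunkCodeGen

    # same toy workload as autochunk_benchmark.py
    model = evoformer_base().cuda()
    node = torch.randn(1, 300, 800, 256).cuda()  # msa_len = 300
    pair = torch.randn(1, 800, 800, 128).cuda()  # pair_len = 800

    # trace once for codegen, once symbolically for memory estimation
    graph = ColoTracer().trace(
        model,
        meta_args={"node": node.to(torch.device("meta")), "pair": pair.to(torch.device("meta"))},
    )
    gm_prop = torch.fx.symbolic_trace(model)
    MetaInfoProp(gm_prop).propagate(
        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
    )

    # pass an explicit budget (MB) instead of relying on the old hard-coded 600 MB
    graph.set_codegen(ChunkCodeGen(gm_prop, max_memory=3000))
    gm = ColoGraphModule(model, graph)
    gm.recompile()

Because the selector is built with stratge="fit_memory", the assert added in ChunkSelector.__init__ requires a non-None max_memory, so callers that want memory-fitted chunking must supply the budget explicitly.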