From 7a23deb58455b112cf187776857e2a262d0b737e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 14:47:16 +0800 Subject: [PATCH] code style --- autochunk_benchmark.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index 0c55a3a88..f8e603f4e 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -34,15 +34,23 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title): def benchmark_evoformer(): - # data + # init data and model msa_len = 300 pair_len = 800 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() - - # build gm model - max_memory = 3000 # MB model = evoformer_base().cuda() + + # build autochunk model + max_memory = 3000 # MB + autochunk = _build_autochunk(model, max_memory, node, pair) + + # benchmark + _benchmark_evoformer(model, node, pair, "openfold") + _benchmark_evoformer(autochunk, node, pair, "autochunk") + + +def _build_autochunk(model, max_memory, node, pair): # trace the module and replace codegen graph = ColoTracer().trace( model, @@ -70,9 +78,7 @@ def benchmark_evoformer(): # print code = graph.python_code("self").src print(code) - - _benchmark_evoformer(gm, node, pair, "autochunk") - _benchmark_evoformer(model, node, pair, "openfold") + return gm if __name__ == "__main__":