diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index 0c55a3a88..f8e603f4e 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -34,15 +34,23 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title): def benchmark_evoformer(): - # data + # init data and model msa_len = 300 pair_len = 800 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() - - # build gm model - max_memory = 3000 # MB model = evoformer_base().cuda() + + # build autochunk model + max_memory = 3000 # MB + autochunk = _build_autochunk(model, max_memory, node, pair) + + # benchmark + _benchmark_evoformer(model, node, pair, "openfold") + _benchmark_evoformer(autochunk, node, pair, "autochunk") + + +def _build_autochunk(model, max_memory, node, pair): # trace the module and replace codegen graph = ColoTracer().trace( model, @@ -70,9 +78,7 @@ def benchmark_evoformer(): # print code = graph.python_code("self").src print(code) - - _benchmark_evoformer(gm, node, pair, "autochunk") - _benchmark_evoformer(model, node, pair, "openfold") + return gm if __name__ == "__main__":