diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index f8e603f4e..20f615b21 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -33,23 +33,6 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title): ) -def benchmark_evoformer(): - # init data and model - msa_len = 300 - pair_len = 800 - node = torch.randn(1, msa_len, pair_len, 256).cuda() - pair = torch.randn(1, pair_len, pair_len, 128).cuda() - model = evoformer_base().cuda() - - # build autochunk model - max_memory = 3000 # MB - autochunk = _build_autochunk(model, max_memory, node, pair) - - # benchmark - _benchmark_evoformer(model, node, pair, "openfold") - _benchmark_evoformer(autochunk, node, pair, "autochunk") - - def _build_autochunk(model, max_memory, node, pair): # trace the module and replace codegen graph = ColoTracer().trace( @@ -81,5 +64,22 @@ def _build_autochunk(model, max_memory, node, pair): return gm +def benchmark_evoformer(): + # init data and model + msa_len = 300 + pair_len = 800 + node = torch.randn(1, msa_len, pair_len, 256).cuda() + pair = torch.randn(1, pair_len, pair_len, 128).cuda() + model = evoformer_base().cuda() + + # build autochunk model + max_memory = 3000 # MB + autochunk = _build_autochunk(model, max_memory, node, pair) + + # benchmark + _benchmark_evoformer(model, node, pair, "openfold") + _benchmark_evoformer(autochunk, node, pair, "autochunk") + + if __name__ == "__main__": benchmark_evoformer()