diff --git a/tests/test_autochunk/benchmark_autochunk.py b/tests/test_autochunk/benchmark_autochunk.py index 7a9d8cdee..6632ece61 100644 --- a/tests/test_autochunk/benchmark_autochunk.py +++ b/tests/test_autochunk/benchmark_autochunk.py @@ -98,14 +98,14 @@ def _build_openfold(): def benchmark_evoformer(): # init data and model msa_len = 256 - pair_len = 256 + pair_len = 512 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() model = evoformer_base().cuda() # build autochunk model - max_memory = 1000 # MB fit memory mode - # max_memory = None # min memory mode + # max_memory = 1000 # MB, fit memory mode + max_memory = None # min memory mode autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) # build openfold