mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-05-31 03:15:40 +00:00)
Support multi-output chunk search. Previously we supported only single-output chunk search. The new strategy is more flexible and improves performance by a large margin; for transformers it reduces memory by 40% compared with the previous search strategy. 1. Rewrite the search strategy to support multi-output chunk search. 2. Fix many, many bugs. 3. Update tests.
88 lines · 2.1 KiB · Python
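For context on the multi-output search described in the commit message: autochunk operates on a torch.fx trace of the model, and a multi-output-aware search has to plan chunked regions for every value the graph returns rather than assuming a single output. Below is a minimal sketch, using plain torch.fx, of how a traced graph's outputs can be enumerated; the TwoHead toy module is hypothetical, and none of this reproduces ColossalAI's actual search or codegen, which live in colossalai.autochunk.

# Minimal sketch (plain torch.fx, NOT the ColossalAI implementation):
# enumerate the outputs of a traced graph, which a multi-output chunk
# search must plan chunks for individually.
import torch
import torch.fx


class TwoHead(torch.nn.Module):
    # Hypothetical toy model with two outputs, for illustration only.
    def forward(self, x):
        return x.relu(), x.sigmoid()


gm = torch.fx.symbolic_trace(TwoHead())
for node in gm.graph.nodes:
    if node.op == "output":
        # node.args[0] is the returned value; a tuple here means the
        # search strategy has several outputs to cover, not just one.
        print("graph outputs:", node.args[0])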
from functools import partial
from typing import List, Tuple

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from fastfold.model.nn.evoformer import EvoformerStack
    HAS_REPO = True
except ImportError:
    HAS_REPO = False

from test_alphafold_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE


def get_model():
    model = EvoformerStack(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        c_s=384,
        no_heads_msa=8,
        no_heads_pair=4,
        no_blocks=2,    # reduced from the full model's 48 blocks to keep the test fast
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.25,
        blocks_per_ckpt=None,
        inf=1000000000.0,
        eps=1e-08,
        clear_cache_between_blocks=False,
        is_multimer=False,
    ).eval().cuda()
    return model


def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    # MSA ("m") and pair ("z") representations plus their masks, shaped as
    # EvoformerStack expects: [batch, msa_len, pair_len, c_m] for the MSA
    # track and [batch, pair_len, pair_len, c_z] for the pair track.
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()

    # meta_args are traced symbolically; concrete_args are pinned at trace time.
    meta_args = [
        ("m", node),
        ("z", pair),
        ("msa_mask", node_mask),
        ("pair_mask", pair_mask),
    ]
    concrete_args = [("chunk_size", None), ("_mask_trans", True)]
    return meta_args, concrete_args


@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0, or fastfold is not installed",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_evoformer_stack(data_args, max_memory):
    run_func = partial(
        run_test,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=None,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
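A note on the two argument lists in get_data: in torch.fx terms, concrete_args are values baked into the trace, while the meta_args tensors drive symbolic shape propagation; this is how chunk_size=None and _mask_trans=True are pinned before the graph is searched. Below is a minimal vanilla-fx illustration of concrete-argument specialization; ColossalAI uses its own tracer, so treat this only as an analogy, not the harness's actual mechanism.

import torch
import torch.fx

def f(x, flag=True):
    # With `flag` supplied via concrete_args, tracing specializes on the
    # chosen branch instead of recording data-dependent control flow.
    return x * 2 if flag else x

gm = torch.fx.symbolic_trace(f, concrete_args={"flag": True})
print(gm.code)    # the traced code contains only the `x * 2` branch

When run directly, the __main__ block calls run_test with rank=0 for single-process debugging; under pytest, the skipif guard disables the test unless both autochunk codegen (torch >= 1.12) and the fastfold repo are importable, and mp.spawn launches the same harness with nprocs=1.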