mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[autochunk] add benchmark for transformer and alphafold (#2543)
This commit is contained in:
@@ -0,0 +1,70 @@
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from diffusers import UNet2DModel
|
||||
MODELS = [UNet2DModel]
|
||||
HAS_REPO = True
|
||||
except:
|
||||
MODELS = []
|
||||
HAS_REPO = False
|
||||
|
||||
from test_autochunk_diffuser_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
|
||||
BATCH_SIZE = 2
|
||||
SEQ_LENGTH = 5
|
||||
HEIGHT = 224
|
||||
WIDTH = 224
|
||||
IN_CHANNELS = 3
|
||||
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
|
||||
|
||||
|
||||
def get_data(shape: tuple) -> Tuple[List, List]:
|
||||
sample = torch.randn(shape)
|
||||
meta_args = [
|
||||
("sample", sample),
|
||||
]
|
||||
concrete_args = [("timestep", 50)]
|
||||
return meta_args, concrete_args
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
True,
|
||||
reason="not implemented",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
|
||||
@pytest.mark.parametrize("max_memory", [64])
|
||||
def test_evoformer_block(model, shape, max_memory):
|
||||
run_func = partial(
|
||||
run_test,
|
||||
max_memory=max_memory,
|
||||
model=model,
|
||||
data=get_data(shape),
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_test(
|
||||
rank=0,
|
||||
data=get_data(LATENTS_SHAPE),
|
||||
max_memory=64,
|
||||
model=UNet2DModel,
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
Reference in New Issue
Block a user