[autochunk] add benchmark for transformer and alphafold (#2543)

Author: oahzxl
Date: 2023-02-02 15:06:43 +08:00
Committed by: GitHub
Parent: 9885ec2b2e
Commit: c4b15661d7
10 changed files with 286 additions and 5 deletions

File: test_autochunk_diffuser_utils.py (new file)

@@ -0,0 +1,120 @@
from typing import Any, Dict, List
import torch
import torch.fx
import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
def assert_codegen_run(
model: Any,
meta_args: List,
concrete_args: List = None,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
) -> List[Dict]:
if concrete_args is None:
concrete_args = []
model = model()
# trace the meta graph and setup codegen
meta_graph = symbolic_trace(
model,
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
interp = MetaInfoProp(meta_graph)
meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
graph = ColoTracer().trace(
model.cuda(),
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()
# assert chunk in code
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_result = None; chunk_size = None;" in code
# assert result
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
model.cuda().eval()
gm.eval()
with torch.no_grad():
out_gm = gm(*inputs)
out_model = model(*inputs)
    assert torch.allclose(out_gm["sample"], out_model["sample"],
                          atol=1e-4), "fx output does not match the original output, mean abs diff is %.2e" % torch.mean(
                              torch.abs(out_gm["sample"] - out_model["sample"]))
return chunks
def run_test(
rank: int,
model: Any,
data: tuple,
max_memory: int,
print_code: bool,
print_mem: bool,
print_progress: bool,
get_chunk_target: Any = None,
) -> None:
# launch colossalai
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
meta_args, concrete_args = data
chunks = assert_codegen_run(
model,
meta_args=meta_args,
concrete_args=concrete_args,
max_memory=max_memory,
print_code=print_code,
print_mem=print_mem,
print_progress=print_progress,
)
if get_chunk_target is not None:
chunk_found = [i["region"] for i in chunks]
chunk_target = get_chunk_target()[max_memory]
assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
str(chunk_found),
str(chunk_target),
)
gpc.destroy()
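
For context when reading the diff above: assert_codegen_run bundles the whole autochunk flow (symbolic_trace plus MetaInfoProp to annotate memory, AutoChunkCodeGen to search for chunk regions under max_memory, then a ColoTracer re-trace and a ColoGraphModule recompile), while run_test only wraps it in a single-process colossalai launch. Below is a minimal sketch of that same flow outside the test harness; the TwoLayerMLP toy model, the autochunk_compile helper name, and the 25 MB budget are illustrative assumptions, not part of this PR, and the calls simply mirror the ones above (a colossalai context is assumed to already be launched, as run_test does).

import torch
import torch.nn as nn

from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


class TwoLayerMLP(nn.Module):
    """Illustrative toy model, not part of this PR."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def autochunk_compile(model: nn.Module, sample: torch.Tensor, max_memory: int = 25) -> ColoGraphModule:
    """Sketch of the autochunk pipeline, mirroring the steps of assert_codegen_run above."""
    # 1) trace on meta tensors so MetaInfoProp can annotate per-node memory usage
    meta_graph = symbolic_trace(model, meta_args={"x": sample.to(torch.device("meta"))})
    MetaInfoProp(meta_graph).propagate(MetaTensor(sample, fake_device="cuda:0"))
    # 2) search for chunk regions under the given memory budget (MB)
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
    # 3) re-trace with ColoTracer, attach the chunked codegen, and recompile
    graph = ColoTracer().trace(model.cuda(), meta_args={"x": sample.to(torch.device("meta"))})
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()
    return gm


# illustrative check, mirroring the equality assertion in assert_codegen_run:
#   model = TwoLayerMLP().cuda().eval()
#   x = torch.randn(64, 256, device="cuda")
#   gm = autochunk_compile(model, x)
#   assert torch.allclose(gm(x), model(x), atol=1e-4)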

File: autochunk test for the diffusers UNet2DModel (new file)

@@ -0,0 +1,70 @@
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from diffusers import UNet2DModel
MODELS = [UNet2DModel]
HAS_REPO = True
except Exception:
MODELS = []
HAS_REPO = False
from test_autochunk_diffuser_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
BATCH_SIZE = 2
SEQ_LENGTH = 5
HEIGHT = 224
WIDTH = 224
IN_CHANNELS = 3
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
def get_data(shape: tuple) -> Tuple[List, List]:
sample = torch.randn(shape)
meta_args = [
("sample", sample),
]
concrete_args = [("timestep", 50)]
return meta_args, concrete_args
@pytest.mark.skipif(
True,
reason="not implemented",
)
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [64])
def test_autochunk_diffuser(model, shape, max_memory):
run_func = partial(
run_test,
max_memory=max_memory,
model=model,
data=get_data(shape),
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data=get_data(LATENTS_SHAPE),
max_memory=64,
model=UNet2DModel,
print_code=False,
print_mem=False,
print_progress=False,
)
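
A note on coverage: the test above never supplies get_chunk_target, so run_test only checks numerical equality and does not verify which regions were chunked. A minimal sketch of such a callback, inferred only from how run_test consumes it (it looks up get_chunk_target()[max_memory] and compares the result against each chunk's "region" entry); the region values below are placeholders for illustration, not measured results.

from typing import Dict, List


def get_chunk_target() -> Dict[int, List]:
    # keys are max_memory budgets (MB) matching the pytest parametrization above;
    # values are the expected chunk regions -- the node-index ranges here are
    # placeholders, not numbers measured on UNet2DModel
    return {
        64: [[100, 140]],
    }

It would be wired in through the same partial, e.g. run_test(..., get_chunk_target=get_chunk_target), after which run_test asserts that the discovered regions match.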