[autochunk] add benchmark for transformer and alphafold (#2543)

oahzxl
2023-02-02 15:06:43 +08:00
committed by GitHub
parent 9885ec2b2e
commit c4b15661d7
10 changed files with 286 additions and 5 deletions

View File

@@ -0,0 +1,150 @@
import time
from typing import Any

import torch
import torch.fx

import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import parameter_size
from colossalai.utils import free_port

if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def _benchmark_autochunk_gpt_gm(
    model: Any,
    data: tuple,
    max_memory: int = None,
) -> None:
    model = model.cuda().eval()

    # build model and input
    meta_args, concrete_args, sequence = data
    if concrete_args is None:
        concrete_args = {}

    # trace the meta graph and set up codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
        concrete_args={k: v for k, v in concrete_args.items()},
    )
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
    )

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model.cuda().eval(),
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
        concrete_args={k: v for k, v in concrete_args.items()},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # init inputs
    inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda().eval()

    # bench
    para_mem = float(parameter_size(model)) / 1024**2 * 6
    act_mem = _benchmark_memory(gm, inputs)
    speed = _benchmark_speed(gm, inputs)
    print("gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
          (speed, act_mem, para_mem, act_mem + para_mem))


def _benchmark_autochunk_gpt_origin(
    model: Any,
    data: tuple,
) -> float:
    # build model and input
    meta_args, concrete_args, sequence = data
    if concrete_args is None:
        concrete_args = {}

    # init inputs
    inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda().eval()

    # bench
    para_mem = float(parameter_size(model)) / 1024**2 * 6
    act_mem = _benchmark_memory(model, inputs)
    speed = _benchmark_speed(model, inputs)
    print("gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
          (speed, act_mem, para_mem, act_mem + para_mem))
    return act_mem


def _benchmark_memory(model, inputs):
    with torch.no_grad():
        torch.cuda.reset_peak_memory_stats()
        now_mem = float(torch.cuda.memory_allocated()) / 1024**2
        model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
        new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
    return new_max_mem - now_mem


def _benchmark_speed(model, inputs, loop=5):
    with torch.no_grad():
        # warm up before timing
        for _ in range(loop // 2 + 1):
            model(*inputs)
        torch.cuda.synchronize()
        time1 = time.time()
        for _ in range(loop):
            model(*inputs)
        torch.cuda.synchronize()
        time2 = time.time()
    return (time2 - time1) / loop


def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12):
    from test_autochunk_gpt import GPT2Config, GPT2Model, get_data

    model = GPT2Model
    config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head)
    config.max_position_embeddings = seq
    model = model(config=config)
    shape = [batch, seq]
    print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head))

    # measure the unmodified model first; its activation memory is the budget baseline
    max_mem = _benchmark_autochunk_gpt_origin(model, get_data(shape))
    # tighten the memory budget until the chunk search reports failure
    for ratio in [0.5, 0.4, 0.3, 0.2]:
        try:
            _benchmark_autochunk_gpt_gm(model, get_data(shape), max_mem * ratio)
        except RuntimeError as e:
            if e.args and e.args[0] == "Search failed. Try a larger memory threshold.":
                break
            raise
    # finally, run autochunk without a memory constraint
    _benchmark_autochunk_gpt_gm(model, get_data(shape), None)


if __name__ == "__main__":
    # launch colossalai as a single-process group
    colossalai.launch(
        config={},
        rank=0,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    benchmark_autochunk_gpt(batch=1, seq=1024, n_embd=768, n_head=12)
    benchmark_autochunk_gpt(batch=1, seq=2048, n_embd=768, n_head=12)
    benchmark_autochunk_gpt(batch=1, seq=4096, n_embd=768, n_head=12)
    benchmark_autochunk_gpt(batch=1, seq=6144, n_embd=768, n_head=12)
    benchmark_autochunk_gpt(batch=1, seq=8192, n_embd=768, n_head=12)
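
A note on the sweep above: each call to benchmark_autochunk_gpt first measures the unmodified model, then reruns autochunk at 50%, 40%, 30%, and 20% of the baseline activation memory until the chunk search fails, and ends with one unconstrained autochunk run. Larger configurations can be benchmarked through the same entry point; for example, a hypothetical extra call with a GPT2-XL-like width (not part of this commit):

    benchmark_autochunk_gpt(batch=1, seq=4096, n_embd=1600, n_head=25)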

View File

@@ -0,0 +1,63 @@
from functools import partial
from typing import Dict, List, Tuple

import pytest
import torch
import torch.multiprocessing as mp

try:
    from transformers import GPT2Config, GPT2Model
    MODELS = [GPT2Model]
    HAS_REPO = True
except ImportError:
    MODELS = []
    HAS_REPO = False

from test_autochunk_transformer_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE

BATCH_SIZE = 1
SEQ_LENGTH = 512


def get_data(shape: tuple) -> Tuple[Dict, Dict, List]:
    input_ids = torch.zeros(shape, dtype=torch.int64)
    token_type_ids = torch.zeros(shape, dtype=torch.int64)
    attention_mask = torch.ones(shape, dtype=torch.int64)
    meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    concrete_args = {"past_key_values": None}
    sequence = ["input_ids", "past_key_values", "attention_mask", "token_type_ids"]
    return meta_args, concrete_args, sequence


@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0 or transformers is not installed",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@pytest.mark.parametrize("max_memory", [None, 6, 8])
def test_autochunk_gpt(model, shape, max_memory):
    run_func = partial(
        run_test,
        data=get_data(shape),
        max_memory=max_memory,
        model=model,
        config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    run_test(
        rank=0,
        data=get_data((BATCH_SIZE, SEQ_LENGTH)),
        max_memory=None,
        model=GPT2Model,
        config=GPT2Config(n_embd=96, n_positions=SEQ_LENGTH, n_layer=2, n_head=4),
        print_code=False,
        print_est_mem=False,
        print_mem=False,
        print_progress=False,
    )
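
Because the test spawns its own single-process worker with mp.spawn, it runs standalone on one GPU. Assuming this file is named test_autochunk_gpt.py, as the import in the benchmark script suggests, it can be invoked directly with pytest test_autochunk_gpt.py or run as a plain script via the __main__ block above.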

View File

@@ -0,0 +1,141 @@
from typing import Any, Dict, List

import torch
import torch.fx

import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def assert_codegen_run(
    model: Any,
    data: tuple,
    max_memory: int = None,
    print_est_mem: bool = False,
    print_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
    meta_args, concrete_args, sequence = data
    if concrete_args is None:
        concrete_args = {}

    # trace the meta graph and set up codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
        concrete_args={k: v for k, v in concrete_args.items()},
    )
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_est_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model.cuda(),
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
        concrete_args={k: v for k, v in concrete_args.items()},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert chunk in code
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_size = None; " in code

    # assert result
    inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda().eval()
    gm.eval()
    with torch.no_grad():
        if print_mem:
            torch.cuda.reset_peak_memory_stats()
            now_mem = torch.cuda.memory_allocated() / 1024**2
        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
        if print_mem:
            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
            print("mem: %.2fMB" % (new_max_mem - now_mem))
        out_model = model(*inputs)
    assert_allclose(out_model, out_gm)
    return chunks


def assert_allclose(out_model: Any, out_gm: Any) -> None:
    """
    Recursively assert that the traced module's outputs match the original model's outputs.
    """
    if isinstance(out_model, torch.Tensor):
        assert torch.allclose(out_model, out_gm, atol=1e-4), \
            "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_model - out_gm))
    elif isinstance(out_model, dict):
        for k in out_model.keys():
            assert_allclose(out_model[k], out_gm[k])
    elif isinstance(out_model, (tuple, list, set)):
        for i, j in zip(out_model, out_gm):
            assert_allclose(i, j)


def run_test(
    rank: int,
    model: Any,
    config: Any,
    data: tuple,
    max_memory: int,
    print_code: bool = False,
    print_est_mem: bool = False,
    print_mem: bool = False,
    print_progress: bool = False,
    get_chunk_target: Any = None,
) -> None:
    model = model(config=config)

    # launch colossalai as a single-process group
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input, then check the generated code and its outputs
    chunks = assert_codegen_run(
        model,
        data=data,
        max_memory=max_memory,
        print_code=print_code,
        print_est_mem=print_est_mem,
        print_mem=print_mem,
        print_progress=print_progress,
    )

    if get_chunk_target is not None:
        chunk_found = [i["region"] for i in chunks]
        chunk_target = get_chunk_target()[max_memory]
        assert chunk_found == chunk_target, "found regions %s don't equal target regions %s" % (
            str(chunk_found),
            str(chunk_target),
        )
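
None of the callers in this commit pass get_chunk_target, but run_test accepts it to pin down exactly which regions AutoChunkCodeGen selected. A caller would supply a zero-argument callable mapping each max_memory setting to the expected chunk regions; a minimal sketch, with illustrative region values that are not taken from this commit:

def get_chunk_target() -> dict:
    # hypothetical expected regions, keyed by the max_memory values used in the test;
    # each entry mirrors chunk_infos[i]["region"] (values are illustrative only)
    return {
        None: [[100, 112], [118, 126]],
        6: [[100, 112]],
        8: [[100, 112], [118, 126]],
    }

run_test would then assert that the regions it found equal get_chunk_target()[max_memory].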