Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-15 22:19:38 +00:00)
[autochunk] support vit (#3084)
support vit for autochunk

* support some new ops for vit
* fix some bugs
* add test for vit
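At a glance, the flow the new ViT test exercises looks like the following minimal sketch, condensed from the test utilities added in this diff. It assumes a single CUDA device, torch >= 1.12 (so AUTOCHUNK_AVAILABLE is true), and timm installed; the names and calls are taken from the diff itself, not a definitive recipe.

import torch

import colossalai
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
from colossalai.utils import free_port
from timm.models.vision_transformer import vit_large_patch16_384 as vit

colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=free_port(), backend="nccl")

model = vit().cuda().eval()
data = torch.rand(1, 3, 384, 384)
meta_args = {"x": data}

# estimate activation memory on a meta-device copy of the graph
meta_graph = symbolic_trace(model, meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()})
MetaInfoProp(meta_graph).propagate(*[MetaTensor(v, fake_device="cuda:0") for v in meta_args.values()])
codegen = AutoChunkCodeGen(meta_graph, max_memory=None)    # memory budget; None mirrors the default test case

# re-trace with ColoTracer (the codegen needs it), attach the chunking codegen, recompile
graph = ColoTracer().trace(model, meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()})
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
gm.recompile()

with torch.no_grad():
    out = gm(data.cuda())    # the test checks this against model(data.cuda()) with atol=1e-3

The same trace-twice pattern (symbolic_trace for MetaInfoProp, ColoTracer for the chunked codegen) appears in both the test utilities and the unet benchmark below.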
@@ -0,0 +1,147 @@
import time
from typing import Any, Dict, List

import torch
import torch.fx

import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import parameter_size
from colossalai.utils import free_port

if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def _benchmark_autochunk_unet_gm(
    model: Any,
    data: tuple,
    max_memory: int = None,
) -> None:
    model = model.cuda().eval()

    # build model and input
    meta_args, concrete_args = data
    if concrete_args is None:
        concrete_args = {}

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    model = model.cuda().eval()
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
    )

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model.cuda().eval(),
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # init inputs
    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda().eval()

    # bench
    para_mem = float(parameter_size(model)) / 1024**2
    act_mem = _benchmark_memory(gm, inputs)
    speed = _benchmark_speed(gm, inputs)
    print("unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
          (speed, act_mem, para_mem, act_mem + para_mem))


def _benchmark_autochunk_unet_origin(
    model: Any,
    data: tuple,
) -> None:
    # build model and input
    meta_args, concrete_args = data
    if concrete_args is None:
        concrete_args = {}

    # init inputs
    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda().eval()

    # bench
    para_mem = float(parameter_size(model)) / 1024**2
    act_mem = _benchmark_memory(model, inputs)
    speed = _benchmark_speed(model, inputs)
    print("unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
          (speed, act_mem, para_mem, act_mem + para_mem))
    return act_mem


def _benchmark_memory(model, inputs):
    with torch.no_grad():
        torch.cuda.reset_peak_memory_stats()
        now_mem = float(torch.cuda.memory_allocated()) / 1024**2
        model(*inputs)
        new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
    return new_max_mem - now_mem


def _benchmark_speed(model, inputs, loop=5):
    with torch.no_grad():
        for _ in range(loop // 2 + 1):
            model(*inputs)
        torch.cuda.synchronize()
        time1 = time.time()
        for _ in range(loop):
            model(*inputs)
        torch.cuda.synchronize()
        time2 = time.time()
    return (time2 - time1) / loop


def benchmark_autochunk_unet(batch=1, height=448, width=448):
    from test_autochunk_unet import UNet2DModel, get_data
    model = UNet2DModel()
    latent_shape = (batch, 3, height // 7, width // 7)

    print("\nbatch: %d, height: %d, width: %d" % (batch, height, width))
    max_mem = _benchmark_autochunk_unet_origin(model, get_data(latent_shape))
    for ratio in [0.5, 0.4, 0.3, 0.2]:
        try:
            _benchmark_autochunk_unet_gm(model, get_data(latent_shape), max_mem * ratio)
        except RuntimeError as e:
            if e.args[0] == 'Search failed. Try a larger memory threshold.':
                break
        except Exception as e:
            raise e
    _benchmark_autochunk_unet_gm(model, get_data(latent_shape), None)


if __name__ == "__main__":
    # launch colossalai
    colossalai.launch(
        config={},
        rank=0,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    benchmark_autochunk_unet(batch=1, height=224 * 2, width=224 * 2)
    benchmark_autochunk_unet(batch=1, height=224 * 3, width=224 * 3)
    benchmark_autochunk_unet(batch=1, height=224 * 4, width=224 * 4)
@@ -39,7 +39,7 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     )

 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("shape", [LATENTS_SHAPE])
-@pytest.mark.parametrize("max_memory", [None])
+@pytest.mark.parametrize("max_memory", [None, 150, 300])
 def test_evoformer_block(model, shape, max_memory):
     run_func = partial(
         run_test,
@@ -57,7 +57,7 @@ if __name__ == "__main__":
         max_memory=None,
         model=UNet2DModel,
         print_code=False,
-        print_mem=False,
+        print_mem=True,
         print_est_mem=False,
         print_progress=False,
     )
@@ -0,0 +1,53 @@
from functools import partial
from typing import List, Tuple

import pytest
import torch
import torch.multiprocessing as mp

try:
    from timm.models.vision_transformer import vit_large_patch16_384 as vit
    MODELS = [vit]
    HAS_REPO = True
except:
    MODELS = []
    HAS_REPO = False

from test_autochunk_vit_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE


def get_data() -> Tuple[List, List]:
    data = torch.rand(1, 3, 384, 384)
    meta_args = {'x': data}
    return data, meta_args


@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_memory", [None, 32, 40])
def test_evoformer_block(model, max_memory):
    run_func = partial(
        run_test,
        max_memory=max_memory,
        model=model,
        data=get_data(),
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    run_test(
        rank=0,
        data=get_data(),
        max_memory=None,
        model=vit,
        print_code=False,
        print_mem=False,
        print_est_mem=False,
        print_progress=False,
    )
@@ -0,0 +1,128 @@
from typing import Any, Dict, List

import torch
import torch.fx

import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def assert_codegen_run(
    model: Any,
    meta_args: Dict,
    data: Any,
    max_memory: int = None,
    print_mem: bool = False,
    print_est_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
    model = model()

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(model, meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()})
    model = model.cuda().eval()
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args.items()]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_est_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model.cuda(),
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert chunk in code
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_size = None; " in code

    # assert result
    inputs = [data.cuda()]
    model.cuda().eval()
    gm.eval()
    with torch.no_grad():
        if print_mem:
            torch.cuda.reset_peak_memory_stats()
            now_mem_gm = torch.cuda.memory_allocated() / 1024**2
        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
        if print_mem:
            max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2
            torch.cuda.reset_peak_memory_stats()
            now_mem_ori = torch.cuda.memory_allocated() / 1024**2
        out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
        if print_mem:
            max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2
            print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm))

    assert torch.allclose(out_gm, out_model,
                          atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(out_gm - out_model))

    return chunks


def run_test(
    rank: int,
    model: Any,
    data: tuple,
    max_memory: int,
    print_code: bool = False,
    print_mem: bool = False,
    print_est_mem: bool = False,
    print_progress: bool = False,
    get_chunk_target: Any = None,
) -> None:
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    data, meta_args = data
    chunks = assert_codegen_run(
        model,
        meta_args=meta_args,
        data=data,
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
        print_est_mem=print_est_mem,
        print_progress=print_progress,
    )

    if get_chunk_target is not None:
        chunk_found = [i["region"] for i in chunks]
        chunk_target = get_chunk_target()[max_memory]
        assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
            str(chunk_found),
            str(chunk_target),
        )

    gpc.destroy()
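run_test also accepts an optional get_chunk_target callable, used to check which regions autochunk decided to chunk for a given max_memory. The new ViT test does not pass it, but a hypothetical caller (reusing vit and get_data from the ViT test above) could look like the sketch below; the region lists are placeholders for illustration only, not values from this test suite.

# hypothetical: verify the searched chunk regions for a given memory budget
def get_chunk_target():
    # placeholder region lists; a real test would fill in the regions autochunk actually finds
    return {
        None: [[100, 112], [118, 126]],
        32: [[100, 112]],
        40: [[100, 112]],
    }

run_test(
    rank=0,
    data=get_data(),
    max_memory=32,
    model=vit,
    get_chunk_target=get_chunk_target,
)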