[autochunk] support diffusion for autochunk (#2621)

* add alphafold benchmark

* rename alphafold test

* rename tests

* rename diffuser

* rename

* rename

* update transformer

* update benchmark

* update benchmark

* update bench memory

* update transformer benchmark

* rename

* support diffuser

* support unet metainfo prop

* fix bug and simplify code

* update linear and support some op

* optimize max region search, support conv

* update unet test

* support some op

* support groupnorm and interpolate

* update flow search

* add fix dim in node flow

* fix utils

* rename

* support diffusion

* update diffuser

* update chunk search

* optimize imports

* import

* finish autochunk

This commit is contained in:
oahzxl
2023-02-07 16:32:45 +08:00
committed by GitHub
parent 291b051171
commit 6ba8364881
6 changed files with 216 additions and 166 deletions
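
For orientation, the path these tests exercise can be sketched roughly as follows. The AutoChunkCodeGen, MetaInfoProp, and MetaTensor calls mirror the hunks below; the symbolic_trace helper, the import paths, and the split of the UNet inputs into meta_args/concrete_args are assumptions, so read this as a sketch of the flow rather than the exact test code.

import torch
from diffusers import UNet2DModel

from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.fx import symbolic_trace                       # assumed import path
from colossalai.fx.passes.meta_info_prop import MetaInfoProp   # assumed import path
from colossalai.fx.profiler import MetaTensor                  # assumed import path

model = UNet2DModel()
# Illustrative inputs; the real tests build these in get_data().
meta_args = [("sample", torch.randn(1, 3, 64, 64))]
concrete_args = [("timestep", 50)]

# Trace on meta tensors so the analysis allocates no real activation memory.
meta_graph = symbolic_trace(
    model,
    meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
    concrete_args={k: v for k, v in concrete_args},
)

# Propagate shape/memory metadata through the traced graph.
interp = MetaInfoProp(meta_graph)
meta_tensors = [MetaTensor(v, fake_device="cuda:0") for _, v in meta_args] + [v for _, v in concrete_args]
interp.propagate(*meta_tensors)

# Search for chunkable regions under an optional peak-memory budget (in MB).
codegen = AutoChunkCodeGen(meta_graph, max_memory=None, print_mem=False, print_progress=False)
print(codegen.chunk_infos)  # regions whose activations will be computed chunk by chunk

The codegen produced this way ends up driving the gm module seen in the hunks below, which assert_codegen_run runs side by side with the eager model to compare peak CUDA memory and the "sample" outputs.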

View File

@@ -22,6 +22,7 @@ def assert_codegen_run(
concrete_args: List = None,
max_memory: int = None,
print_mem: bool = False,
print_est_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
) -> List[Dict]:
@@ -35,13 +36,14 @@ def assert_codegen_run(
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
concrete_args={k: v for k, v in concrete_args},
)
model = model.cuda().eval()
interp = MetaInfoProp(meta_graph)
meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_mem,
print_mem=print_est_mem,
print_progress=print_progress,
)
chunks = codegen.chunk_infos
@@ -61,17 +63,29 @@ def assert_codegen_run(
code = graph.python_code("self").src
if print_code:
print(code)
assert "chunk_result = None; chunk_size = None;" in code
assert "chunk_size = None; " in code
# assert result
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
model.cuda().eval()
gm.eval()
with torch.no_grad():
out_gm = gm(*inputs)
out_model = model(*inputs)
if print_mem:
torch.cuda.reset_peak_memory_stats()
now_mem_gm = torch.cuda.memory_allocated() / 1024**2
out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
if print_mem:
max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2
torch.cuda.reset_peak_memory_stats()
now_mem_ori = torch.cuda.memory_allocated() / 1024**2
out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
if print_mem:
max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2
print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm))
assert torch.allclose(out_gm["sample"], out_model["sample"],
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
torch.abs(out_gm["sample"] - out_model["sample"]))
return chunks
@@ -82,9 +96,10 @@ def run_test(
model: Any,
data: tuple,
max_memory: int,
print_code: bool,
print_mem: bool,
print_progress: bool,
print_code: bool = False,
print_mem: bool = False,
print_est_mem: bool = False,
print_progress: bool = False,
get_chunk_target: Any = None,
) -> None:
# launch colossalai
@@ -106,6 +121,7 @@ def run_test(
max_memory=max_memory,
print_code=print_code,
print_mem=print_mem,
print_est_mem=print_est_mem,
print_progress=print_progress,
)

View File

@@ -17,10 +17,9 @@ from test_autochunk_diffuser_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
BATCH_SIZE = 2
SEQ_LENGTH = 5
HEIGHT = 224
WIDTH = 224
BATCH_SIZE = 1
HEIGHT = 448
WIDTH = 448
IN_CHANNELS = 3
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
@@ -34,26 +33,19 @@ def get_data(shape: tuple) -> Tuple[List, List]:
return meta_args, concrete_args
@pytest.mark.skipif(
True,
reason="not implemented",
)
@pytest.mark.skipif(
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
@pytest.mark.parametrize("max_memory", [64])
@pytest.mark.parametrize("max_memory", [None])
def test_evoformer_block(model, shape, max_memory):
run_func = partial(
run_test,
max_memory=max_memory,
model=model,
data=get_data(shape),
print_code=False,
print_mem=False,
print_progress=False,
)
mp.spawn(run_func, nprocs=1)
@@ -62,9 +54,10 @@ if __name__ == "__main__":
run_test(
rank=0,
data=get_data(LATENTS_SHAPE),
max_memory=64,
max_memory=None,
model=UNet2DModel,
print_code=False,
print_mem=False,
print_est_mem=False,
print_progress=False,
)