[autochunk] support multi outputs chunk search (#2538)

Support multi-output chunk search. Previously we only supported single-output chunk search. The new strategy is more flexible and improves performance by a large margin: for transformers, it reduces memory by 40% compared with the previous search strategy. A brief usage sketch of the updated test flags follows the change list below.

1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests
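
As a quick orientation for the test changes below, here is a minimal usage sketch of the reworked test helper. It illustrates the split this commit introduces between the two memory flags: print_est_mem is forwarded to AutoChunkCodeGen and prints the memory estimated during chunk search, while print_mem measures the actual CUDA peak memory around the generated forward pass. The call mirrors the updated __main__ block of the GPT test in this diff; the import paths and flag values are illustrative assumptions, not part of the commit.

# Minimal sketch; run_test/get_data come from the test files shown below (import paths assumed).
from transformers import GPT2Config, GPT2Model
from test_transformer_utils import run_test       # defined in the last file of this diff
from test_autochunk_gpt import get_data           # hypothetical module name for the GPT test

run_test(
    rank=0,
    data=get_data((1, 512)),                       # BATCH_SIZE = 1, SEQ_LENGTH = 512 as in the new test
    max_memory=None,                               # let the chunk search pick regions freely
    model=GPT2Model,
    config=GPT2Config(n_embd=96, n_position=512, n_layer=2, n_head=4),
    print_est_mem=True,                            # print memory estimated by the chunk search
    print_mem=True,                                # print measured CUDA peak memory of the chunked forward
    print_progress=False,
)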
Author: oahzxl
Date: 2023-02-01 13:18:51 +08:00
Committed by: GitHub
Parent: f477a14f4a
Commit: 05671fcb42
14 changed files with 428 additions and 258 deletions

View File

@@ -23,6 +23,7 @@ def assert_codegen_run(
     concrete_args: List = None,
     max_memory: int = None,
     print_mem: bool = False,
+    print_est_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,
 ) -> List[Dict]:
@@ -41,7 +42,7 @@ def assert_codegen_run(
     codegen = AutoChunkCodeGen(
         meta_graph,
         max_memory=max_memory,
-        print_mem=print_mem,
+        print_mem=print_est_mem,
         print_progress=print_progress,
     )
     chunks = codegen.chunk_infos
@@ -61,13 +62,20 @@ def assert_codegen_run(
     code = graph.python_code("self").src
     if print_code:
         print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
+    assert "chunk_size = None; " in code
     # assert result
     inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
     inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
     model.cuda()
     with torch.no_grad():
-        out_gm = gm(*inputs)
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+            print("mem: %.2fMB" % (new_max_mem - now_mem))
         out_model = model(*inputs)
     out_gm = flat_list(out_gm)
     out_model = flat_list(out_model)
@@ -85,9 +93,10 @@ def run_test(
     max_memory: int,
     get_model: Any,
     get_data: Any,
-    print_code: bool,
-    print_mem: bool,
-    print_progress: bool,
+    print_code: bool = False,
+    print_mem: bool = False,
+    print_est_mem: bool = False,
+    print_progress: bool = False,
     get_chunk_target: Any = None,
 ) -> None:
     # launch colossalai
@@ -110,6 +119,7 @@ def run_test(
         max_memory=max_memory,
         print_code=print_code,
         print_mem=print_mem,
+        print_est_mem=print_est_mem,
         print_progress=print_progress,

View File

@@ -55,9 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
 def get_chunk_target() -> Dict:
     return {
-        None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)],
-        20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)],
-        24: [(118, 123)],
+        None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242),
+               (25, 50)],
+        20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)],
+        24: [(120, 123)],
     }
@@ -75,9 +76,6 @@ def test_evoformer_block(data_args, max_memory):
         get_model=get_model,
         get_data=get_data,
         get_chunk_target=get_chunk_target,
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)
@@ -86,10 +84,12 @@ if __name__ == "__main__":
     run_test(
         rank=0,
         data_args=(32, 64),
-        max_memory=20,
+        max_memory=24,
         get_model=get_model,
         get_data=get_data,
+        get_chunk_target=get_chunk_target,
         print_code=False,
         print_mem=False,
+        print_est_mem=False,
         print_progress=False,
     )

View File

@@ -70,9 +70,6 @@ def test_evoformer_stack(data_args, max_memory):
         max_memory=max_memory,
         get_model=get_model,
         get_data=get_data,
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)
@@ -81,7 +78,7 @@ if __name__ == "__main__":
     run_test(
         rank=0,
         data_args=(32, 64),
-        max_memory=20,
+        max_memory=None,
         get_model=get_model,
         get_data=get_data,
         print_code=False,

View File

@@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
 def get_chunk_target() -> Dict:
     return {
-        None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250),
-               (33, 46)],
-        20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)],
-        24: [(126, 131)],
+        None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250),
+               (36, 46)],
+        20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)],
+        24: [(128, 131)],
     }
@@ -75,9 +75,7 @@ def test_extramsa_block(data_args, max_memory):
         max_memory=max_memory,
         get_model=get_model,
         get_data=get_data,
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
+        get_chunk_target=get_chunk_target,
     )
     mp.spawn(run_func, nprocs=1)
@@ -86,7 +84,7 @@ if __name__ == "__main__":
     run_test(
         rank=0,
         data_args=(32, 64),
-        max_memory=20,
+        max_memory=None,
         get_model=get_model,
         get_data=get_data,
         get_chunk_target=get_chunk_target,

View File

@@ -17,8 +17,8 @@ from test_transformer_utils import run_test
 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
-BATCH_SIZE = 2
-SEQ_LENGTH = 256
+BATCH_SIZE = 1
+SEQ_LENGTH = 512
 def get_data(shape: tuple) -> Tuple[List, List]:
@@ -37,17 +37,14 @@ def get_data(shape: tuple) -> Tuple[List, List]:
 )
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
-@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
-def test_gpt(model, shape, max_memory):
+@pytest.mark.parametrize("max_memory", [None, 6, 8])
+def test_autochunk_gpt(model, shape, max_memory):
     run_func = partial(
         run_test,
         data=get_data(shape),
         max_memory=max_memory,
         model=model,
         config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
-        print_code=False,
-        print_mem=False,
-        print_progress=False,
     )
     mp.spawn(run_func, nprocs=1)
@@ -59,7 +56,8 @@ if __name__ == "__main__":
         max_memory=None,
         model=GPT2Model,
         config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
-        print_code=True,
-        print_mem=True,
+        print_code=False,
+        print_est_mem=False,
+        print_mem=False,
         print_progress=False,
     )

View File

@@ -20,6 +20,7 @@ def assert_codegen_run(
     model: Any,
     data: tuple,
     max_memory: int = None,
+    print_est_mem: bool = False,
     print_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,
@@ -41,7 +42,7 @@ def assert_codegen_run(
     codegen = AutoChunkCodeGen(
         meta_graph,
         max_memory=max_memory,
-        print_mem=print_mem,
+        print_mem=print_est_mem,
         print_progress=print_progress,
     )
     chunks = codegen.chunk_infos
@@ -61,7 +62,7 @@ def assert_codegen_run(
     code = graph.python_code("self").src
     if print_code:
         print(code)
-    assert "chunk_result = None; chunk_size = None;" in code
+    assert "chunk_size = None; " in code
     # assert result
     inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
@@ -69,26 +70,44 @@ def assert_codegen_run(
     model.cuda().eval()
     gm.eval()
     with torch.no_grad():
-        out_gm = gm(*inputs)
+        if print_mem:
+            torch.cuda.reset_peak_memory_stats()
+            now_mem = torch.cuda.memory_allocated() / 1024**2
+        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        if print_mem:
+            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+            print("mem: %.2fMB" % (new_max_mem - now_mem))
         out_model = model(*inputs)
-    for k in out_model.keys():
-        if torch.is_tensor(out_gm[k]):
-            assert torch.equal(
-                out_model[k], out_gm[k]
-            ), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
+    assert_allclose(out_model, out_gm)
     return chunks
+def assert_allclose(out_model: Any, out_gm: Any) -> None:
+    """
+    assert allclose for out
+    """
+    if isinstance(out_model, torch.Tensor):
+        assert torch.allclose(out_model, out_gm,
+                              atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                                  torch.abs(out_model - out_gm))
+    elif isinstance(out_model, dict):
+        for k in out_model.keys():
+            assert_allclose(out_model[k], out_gm[k])
+    elif isinstance(out_model, tuple) or isinstance(out_model, list) or isinstance(out_model, set):
+        for i, j in zip(out_model, out_gm):
+            assert_allclose(i, j)
 def run_test(
     rank: int,
     model: Any,
     config: Any,
     data: tuple,
     max_memory: int,
-    print_code: bool,
-    print_mem: bool,
-    print_progress: bool,
+    print_code: bool = False,
+    print_est_mem: bool = False,
+    print_mem: bool = False,
+    print_progress: bool = False,
     get_chunk_target: Any = None,
 ) -> None:
     model = model(config=config)
@@ -108,6 +127,7 @@ def run_test(
         data=data,
         max_memory=max_memory,
         print_code=print_code,
+        print_est_mem=print_est_mem,
         print_mem=print_mem,
         print_progress=print_progress,
     )
@@ -119,5 +139,3 @@ def run_test(
             str(chunk_found),
             str(chunk_target),
         )
-    gpc.destroy()
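
For reference, here is a hedged sketch of how the assert_allclose helper added above behaves on nested model outputs; the import path is an assumption based on the module name used by the GPT test, and the tensors are made up.

import torch
from test_transformer_utils import assert_allclose   # assumed import; the helper is defined in the diff above

ref = {"last_hidden_state": torch.randn(1, 512, 96)}
out = {"last_hidden_state": ref["last_hidden_state"] + 1e-6}   # well within the 1e-4 tolerance

# Recurses into dicts/lists/tuples and applies torch.allclose(..., atol=1e-4) to each tensor leaf.
assert_allclose(ref, out)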