Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 21:09:18 +00:00)
[autochunk] support multi outputs chunk search (#2538)
Support multi-output chunk search. Previously only single-output chunk search was supported. The new strategy is more flexible and improves performance by a large margin: for transformers, it reduces memory by 40% compared with the previous search strategy.
1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests
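For orientation, every test touched by this diff drives the chunk search through `AutoChunkCodeGen`. Below is a minimal sketch of that flow pieced together from the hunks that follow; the `search_chunks` wrapper, its `torch.fx.GraphModule` typing, and the default budget of 24 MB are illustrative assumptions, not part of this commit.

```python
import torch.fx
from typing import Dict, List

# import path assumed from the AUTOCHUNK_AVAILABLE import shown in the diff below
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen


def search_chunks(meta_graph: torch.fx.GraphModule, max_memory: int = 24) -> List[Dict]:
    """Run the (now multi-output) chunk region search under a memory budget in MB."""
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=False,        # prints the memory the search *estimates*; the tests
        print_progress=False,   # expose this as the separate print_est_mem flag
    )
    # chunk_infos lists the regions the search decided to chunk; the tests later
    # regenerate the forward source and check that it contains "chunk_size = None; "
    return codegen.chunk_infos
```

The `max_memory` sweep in the updated parametrizations (e.g. `[None, 6, 8]` for GPT-2) exercises exactly this budget argument.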
@@ -23,6 +23,7 @@ def assert_codegen_run(
    concrete_args: List = None,
    max_memory: int = None,
    print_mem: bool = False,
    print_est_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
@@ -41,7 +42,7 @@ def assert_codegen_run(
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_mem,
        print_mem=print_est_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos
@@ -61,13 +62,20 @@ def assert_codegen_run(
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_result = None; chunk_size = None;" in code
    assert "chunk_size = None; " in code

    # assert result
    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda()
    with torch.no_grad():
        out_gm = gm(*inputs)
        if print_mem:
            torch.cuda.reset_peak_memory_stats()
            now_mem = torch.cuda.memory_allocated() / 1024**2
        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
        if print_mem:
            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
            print("mem: %.2fMB" % (new_max_mem - now_mem))
        out_model = model(*inputs)
    out_gm = flat_list(out_gm)
    out_model = flat_list(out_model)
@@ -85,9 +93,10 @@ def run_test(
    max_memory: int,
    get_model: Any,
    get_data: Any,
    print_code: bool,
    print_mem: bool,
    print_progress: bool,
    print_code: bool = False,
    print_mem: bool = False,
    print_est_mem: bool = False,
    print_progress: bool = False,
    get_chunk_target: Any = None,
) -> None:
    # launch colossalai
@@ -110,6 +119,7 @@ def run_test(
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
        print_est_mem=print_est_mem,
        print_progress=print_progress,
    )

@@ -55,9 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:

def get_chunk_target() -> Dict:
    return {
        None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)],
        20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)],
        24: [(118, 123)],
        None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242),
               (25, 50)],
        20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)],
        24: [(120, 123)],
    }


@@ -75,9 +76,6 @@ def test_evoformer_block(data_args, max_memory):
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)

@@ -86,10 +84,12 @@ if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=20,
        max_memory=24,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
        print_code=False,
        print_mem=False,
        print_est_mem=False,
        print_progress=False,
    )

@@ -70,9 +70,6 @@ def test_evoformer_stack(data_args, max_memory):
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)

@@ -81,7 +78,7 @@ if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=20,
        max_memory=None,
        get_model=get_model,
        get_data=get_data,
        print_code=False,

@@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:

def get_chunk_target() -> Dict:
    return {
        None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250),
               (33, 46)],
        20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)],
        24: [(126, 131)],
        None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250),
               (36, 46)],
        20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)],
        24: [(128, 131)],
    }


@@ -75,9 +75,7 @@ def test_extramsa_block(data_args, max_memory):
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
        get_chunk_target=get_chunk_target,
    )
    mp.spawn(run_func, nprocs=1)

@@ -86,7 +84,7 @@ if __name__ == "__main__":
    run_test(
        rank=0,
        data_args=(32, 64),
        max_memory=20,
        max_memory=None,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,

@@ -17,8 +17,8 @@ from test_transformer_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE

BATCH_SIZE = 2
SEQ_LENGTH = 256
BATCH_SIZE = 1
SEQ_LENGTH = 512


def get_data(shape: tuple) -> Tuple[List, List]:
@@ -37,17 +37,14 @@ def get_data(shape: tuple) -> Tuple[List, List]:
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
def test_gpt(model, shape, max_memory):
@pytest.mark.parametrize("max_memory", [None, 6, 8])
def test_autochunk_gpt(model, shape, max_memory):
    run_func = partial(
        run_test,
        data=get_data(shape),
        max_memory=max_memory,
        model=model,
        config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)

@@ -59,7 +56,8 @@ if __name__ == "__main__":
        max_memory=None,
        model=GPT2Model,
        config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
        print_code=True,
        print_mem=True,
        print_code=False,
        print_est_mem=False,
        print_mem=False,
        print_progress=False,
    )

@@ -20,6 +20,7 @@ def assert_codegen_run(
    model: Any,
    data: tuple,
    max_memory: int = None,
    print_est_mem: bool = False,
    print_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
@@ -41,7 +42,7 @@ def assert_codegen_run(
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_mem,
        print_mem=print_est_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos
@@ -61,7 +62,7 @@ def assert_codegen_run(
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_result = None; chunk_size = None;" in code
    assert "chunk_size = None; " in code

    # assert result
    inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
@@ -69,26 +70,44 @@ def assert_codegen_run(
    model.cuda().eval()
    gm.eval()
    with torch.no_grad():
        out_gm = gm(*inputs)
        if print_mem:
            torch.cuda.reset_peak_memory_stats()
            now_mem = torch.cuda.memory_allocated() / 1024**2
        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
        if print_mem:
            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
            print("mem: %.2fMB" % (new_max_mem - now_mem))
        out_model = model(*inputs)
    for k in out_model.keys():
        if torch.is_tensor(out_gm[k]):
            assert torch.equal(
                out_model[k], out_gm[k]
            ), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'

    assert_allclose(out_model, out_gm)
    return chunks


def assert_allclose(out_model: Any, out_gm: Any) -> None:
    """
    assert allclose for out
    """
    if isinstance(out_model, torch.Tensor):
        assert torch.allclose(out_model, out_gm,
                              atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                                  torch.abs(out_model - out_gm))
    elif isinstance(out_model, dict):
        for k in out_model.keys():
            assert_allclose(out_model[k], out_gm[k])
    elif isinstance(out_model, tuple) or isinstance(out_model, list) or isinstance(out_model, set):
        for i, j in zip(out_model, out_gm):
            assert_allclose(i, j)


def run_test(
    rank: int,
    model: Any,
    config: Any,
    data: tuple,
    max_memory: int,
    print_code: bool,
    print_mem: bool,
    print_progress: bool,
    print_code: bool = False,
    print_est_mem: bool = False,
    print_mem: bool = False,
    print_progress: bool = False,
    get_chunk_target: Any = None,
) -> None:
    model = model(config=config)
@@ -108,6 +127,7 @@ def run_test(
        data=data,
        max_memory=max_memory,
        print_code=print_code,
        print_est_mem=print_est_mem,
        print_mem=print_mem,
        print_progress=print_progress,
    )
@@ -119,5 +139,3 @@ def run_test(
            str(chunk_found),
            str(chunk_target),
        )

    gpc.destroy()

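The recursive `assert_allclose` helper introduced above replaces the old per-key `torch.equal` loop. A small usage sketch follows, assuming the helper is importable from the `test_transformer_utils` module these hunks modify; the output dict below is invented for illustration and mirrors the BATCH_SIZE=1, SEQ_LENGTH=512, n_embd=96 settings used in the GPT-2 test.

```python
import torch

from test_transformer_utils import assert_allclose  # assumed import location

# invented example: nested dict-of-tensors outputs, as a GPT2Model forward returns
out_model = {"last_hidden_state": torch.randn(1, 512, 96)}
out_gm = {"last_hidden_state": out_model["last_hidden_state"].clone()}

# walks dicts, tuples/lists/sets and tensors recursively and applies
# torch.allclose(..., atol=1e-4) on every tensor leaf
assert_allclose(out_model, out_gm)
```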