Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 20:40:34 +00:00
[autochunk] refactor chunk memory estimation (#2762)
* refactor memory code
* don't log free var memory
* add memory align
* update chunk target
* update setting for new memory
* finish test
* update tracer
* fix typo
* update test
@@ -61,7 +61,7 @@ def _benchmark_evoformer_stack_gm(
     # bench
     mem = _benchmark_memory(gm, inputs)
     speed = _benchmark_speed(gm, inputs)
-    print("evoformer stack gm, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args)))
+    print("evoformer stack gm, mem: %.2fMB, time: %.4fs" % (mem, speed))


 def _benchmark_evoformer_stack_origin(
@@ -83,14 +83,15 @@ def _benchmark_evoformer_stack_origin(
     # bench
     mem = _benchmark_memory(model, inputs)
     speed = _benchmark_speed(model, inputs)
-    print("evoformer stack origin, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args)))
+    print("evoformer stack origin, mem: %.2fMB, time: %.4fs" % (mem, speed))
+    return mem


 def _benchmark_memory(model, inputs):
     with torch.no_grad():
         torch.cuda.reset_peak_memory_stats()
         now_mem = torch.cuda.memory_allocated() / 1024**2
-        model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        model(*inputs)
         new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
         return new_max_mem - now_mem
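Both benchmark files measure memory the same way: reset the CUDA peak-memory counter, record what is already allocated, run a single no-grad forward pass, and report how much the peak grew. A minimal standalone sketch of that pattern (model and inputs are placeholders):

import torch

def peak_forward_memory_mb(model, inputs):
    """Peak GPU memory (MB) used by one forward pass, beyond what was already allocated."""
    with torch.no_grad():
        torch.cuda.reset_peak_memory_stats()                  # forget peaks from earlier runs
        base_mb = torch.cuda.memory_allocated() / 1024**2     # memory held before the pass
        model(*inputs)                                        # single forward pass
        peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    return peak_mb - base_mb

Dropping the input clone here plausibly lines up with the "don't log free var memory" change in the estimator, so measured and estimated memory exclude the inputs in the same way.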
@@ -108,13 +109,18 @@ def _benchmark_speed(model, inputs, loop=5):
     return (time2 - time1) / loop


-def benchmark_evoformer_stack():
+def benchmark_evoformer_stack(data_args):
     from test_autochunk_evoformer_stack import get_data, get_model
-    data_args = [128, 256]
-    print("")
-    _benchmark_evoformer_stack_origin(data_args, get_model, get_data)
-    _benchmark_evoformer_stack_gm(data_args, 600, get_model, get_data)
-    _benchmark_evoformer_stack_gm(data_args, 400, get_model, get_data)
+    print("\nmsa len: %d, pair len: %d" % (data_args[0], data_args[1]))
+    max_mem = _benchmark_evoformer_stack_origin(data_args, get_model, get_data)
+    for ratio in [0.5, 0.4, 0.3, 0.2, 0.1]:
+        try:
+            _benchmark_evoformer_stack_gm(data_args, max_mem * ratio, get_model, get_data)
+        except RuntimeError as e:
+            if e.args[0] == 'Search failed. Try a larger memory threshold.':
+                break
+        except Exception as e:
+            raise e
+    _benchmark_evoformer_stack_gm(data_args, None, get_model, get_data)
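The new driver sweeps the chunk-search memory budget as a fraction of the unchunked peak and stops at the first budget the search cannot satisfy, then runs once with no budget at all. The same pattern reduced to its essentials (run_with_budget is a placeholder for the autochunk benchmark call):

def sweep_budgets(max_mem, run_with_budget, ratios=(0.5, 0.4, 0.3, 0.2, 0.1)):
    """Benchmark under shrinking memory budgets until the chunk search gives up."""
    for ratio in ratios:
        try:
            run_with_budget(max_mem * ratio)
        except RuntimeError as e:
            # AutoChunk raises this exact message when no chunk plan fits the budget.
            if e.args and e.args[0] == 'Search failed. Try a larger memory threshold.':
                break
            raise
    run_with_budget(None)  # no budget: let the search pick freely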
@@ -128,4 +134,7 @@ if __name__ == "__main__":
         port=free_port(),
         backend="nccl",
     )
-    benchmark_evoformer_stack()
+    benchmark_evoformer_stack((256, 256))
+    benchmark_evoformer_stack((256, 512))
+    benchmark_evoformer_stack((256, 1024))
+    benchmark_evoformer_stack((256, 1280))
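Since data_args is printed as "msa len, pair len", the new entry point is equivalent to sweeping a fixed MSA length of 256 over growing pair lengths:

for msa_len, pair_len in [(256, 256), (256, 512), (256, 1024), (256, 1280)]:
    benchmark_evoformer_stack((msa_len, pair_len))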
@@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:

 def get_chunk_target() -> Dict:
     return {
-        None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242),
-               (25, 50)],
-        20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)],
-        24: [(120, 123)],
+        None: [(120, 126), (225, 244), (270, 289), (306, 311), (70, 106), (23, 46), (146, 152), (187, 193), (181, 184),
+               (140, 145), (162, 163), (203, 204)],
+        20: [(120, 123), (232, 237), (277, 282), (305, 306)],
+        24: [(122, 123)],
     }
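As used by run_test further down, each key of this dict appears to be a max_memory budget (in MB, with None meaning unbounded) and each value the chunk regions, as (start, end) node-index ranges, that the search is expected to pick under that budget. A hedged sketch of the comparison the test presumably performs (the exact assertion inside run_test is not shown in this diff):

def check_chunk_regions(chunks, chunk_target, max_memory):
    """Compare regions found by AutoChunkCodeGen with the expected table entry."""
    found = [info["region"] for info in chunks]      # e.g. [(120, 126), (225, 244), ...]
    expected = chunk_target[max_memory]              # table keyed by the memory budget
    assert sorted(found) == sorted(expected), "found %s, expected %s" % (found, expected)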
@@ -53,15 +53,6 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
     return meta_args, concrete_args


-def get_chunk_target() -> Dict:
-    return {
-        None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250),
-               (36, 46)],
-        20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)],
-        24: [(128, 131)],
-    }
-
-
 @pytest.mark.skipif(
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
@@ -75,7 +66,6 @@ def test_extramsa_block(data_args, max_memory):
         max_memory=max_memory,
         get_model=get_model,
         get_data=get_data,
-        get_chunk_target=get_chunk_target,
     )
     mp.spawn(run_func, nprocs=1)

@@ -87,7 +77,6 @@ if __name__ == "__main__":
         max_memory=None,
         get_model=get_model,
         get_data=get_data,
-        get_chunk_target=get_chunk_target,
         print_code=False,
         print_mem=False,
         print_progress=False,
@@ -95,7 +95,7 @@ def _benchmark_memory(model, inputs):
     with torch.no_grad():
         torch.cuda.reset_peak_memory_stats()
         now_mem = float(torch.cuda.memory_allocated()) / 1024**2
-        model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+        model(*inputs)
         new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
         return new_max_mem - now_mem
@@ -116,8 +116,7 @@ def _benchmark_speed(model, inputs, loop=5):
 def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12):
     from test_autochunk_gpt import GPT2Config, GPT2Model, get_data
     model = GPT2Model
-    config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=n_head)
-    config.max_position_embeddings = seq
+    config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head)
     model = model(config=config)
     shape = [batch, seq]
     print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head))
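The typo matters because Hugging Face's GPT2Config silently absorbs unknown keyword arguments, so `n_position=seq` never changed the context window and the old code had to patch `max_position_embeddings` by hand afterwards. With the correct `n_positions` keyword that patch is redundant; a quick sanity check (values are illustrative):

from transformers import GPT2Config

cfg = GPT2Config(n_embd=768, n_positions=512, n_layer=2, n_head=12)
# For GPT-2, max_position_embeddings is an alias of n_positions in transformers.
assert cfg.n_positions == 512
assert cfg.max_position_embeddings == 512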
@@ -44,20 +44,19 @@ def test_autochunk_gpt(model, shape, max_memory):
         data=get_data(shape),
         max_memory=max_memory,
         model=model,
-        config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
+        config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
     )
     mp.spawn(run_func, nprocs=1)


 if __name__ == "__main__":
-    run_test(
-        rank=0,
-        data=get_data((BATCH_SIZE, SEQ_LENGTH)),
-        max_memory=None,
-        model=GPT2Model,
-        config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
-        print_code=False,
-        print_est_mem=False,
-        print_mem=False,
-        print_progress=False,
-    )
+    run_test(rank=0,
+             data=get_data((BATCH_SIZE, SEQ_LENGTH)),
+             max_memory=None,
+             model=GPT2Model,
+             config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
+             print_code=False,
+             print_est_mem=False,
+             print_mem=False,
+             print_progress=False,
+             eval_mem=False)
@@ -24,6 +24,7 @@ def assert_codegen_run(
     print_mem: bool = False,
     print_progress: bool = False,
     print_code: bool = False,
+    eval_mem: bool = False,
 ) -> List[Dict]:
     meta_args, concrete_args, sequence = data
     if concrete_args is None:
@@ -39,12 +40,11 @@ def assert_codegen_run(
     meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
     meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
     interp.propagate(*meta_tensors)
-    codegen = AutoChunkCodeGen(
-        meta_graph,
-        max_memory=max_memory,
-        print_mem=print_est_mem,
-        print_progress=print_progress,
-    )
+    codegen = AutoChunkCodeGen(meta_graph,
+                               max_memory=max_memory,
+                               print_mem=print_est_mem,
+                               print_progress=print_progress,
+                               eval_mem=eval_mem)
     chunks = codegen.chunk_infos

     # trace and recompile
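Taken together, the hunk shows the call pattern for driving the chunk search from a traced meta graph: propagate meta tensors so every node carries shape information, construct AutoChunkCodeGen with an optional memory budget, and read back codegen.chunk_infos. A condensed sketch of that wiring, assuming the import path these tests use and treating eval_mem simply as the flag this commit threads through:

from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen  # import path assumed

def build_codegen(meta_graph, interp, meta_tensors, max_memory=None, eval_mem=False):
    # propagate meta tensors so every graph node carries shape/memory metadata
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(meta_graph,
                               max_memory=max_memory,    # budget in MB; None lets the search decide
                               print_mem=False,
                               print_progress=False,
                               eval_mem=eval_mem)        # flag added by this commit
    return codegen, codegen.chunk_infos                  # chunk_infos lists the chosen regions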
@@ -108,6 +108,7 @@ def run_test(
     print_est_mem: bool = False,
     print_mem: bool = False,
     print_progress: bool = False,
+    eval_mem: bool = False,
     get_chunk_target: Any = None,
 ) -> None:
     model = model(config=config)
@@ -122,15 +123,14 @@ def run_test(
     )

     # build model and input
-    chunks = assert_codegen_run(
-        model,
-        data=data,
-        max_memory=max_memory,
-        print_code=print_code,
-        print_est_mem=print_est_mem,
-        print_mem=print_mem,
-        print_progress=print_progress,
-    )
+    chunks = assert_codegen_run(model,
+                                data=data,
+                                max_memory=max_memory,
+                                print_code=print_code,
+                                print_est_mem=print_est_mem,
+                                print_mem=print_mem,
+                                print_progress=print_progress,
+                                eval_mem=eval_mem)

     if get_chunk_target is not None:
         chunk_found = [i["region"] for i in chunks]