[autochunk] refactor chunk memory estimation (#2762)

* refact memory code

* dont log free var memory

* add memory align

* update chunk target

* update setting for new memory

* finish test

* update tracer

* update typo

* update test
This commit is contained in:
Xuanlei Zhao
2023-03-08 16:22:30 +08:00
committed by GitHub
parent b51bfec357
commit 2ca9728cbb
12 changed files with 294 additions and 422 deletions

View File

@@ -95,7 +95,7 @@ def _benchmark_memory(model, inputs):
with torch.no_grad():
torch.cuda.reset_peak_memory_stats()
now_mem = float(torch.cuda.memory_allocated()) / 1024**2
model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
model(*inputs)
new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
return new_max_mem - now_mem
@@ -116,8 +116,7 @@ def _benchmark_speed(model, inputs, loop=5):
def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12):
from test_autochunk_gpt import GPT2Config, GPT2Model, get_data
model = GPT2Model
config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=n_head)
config.max_position_embeddings = seq
config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head)
model = model(config=config)
shape = [batch, seq]
print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head))

View File

@@ -44,20 +44,19 @@ def test_autochunk_gpt(model, shape, max_memory):
data=get_data(shape),
max_memory=max_memory,
model=model,
config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
run_test(
rank=0,
data=get_data((BATCH_SIZE, SEQ_LENGTH)),
max_memory=None,
model=GPT2Model,
config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
print_code=False,
print_est_mem=False,
print_mem=False,
print_progress=False,
)
run_test(rank=0,
data=get_data((BATCH_SIZE, SEQ_LENGTH)),
max_memory=None,
model=GPT2Model,
config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
print_code=False,
print_est_mem=False,
print_mem=False,
print_progress=False,
eval_mem=False)

View File

@@ -24,6 +24,7 @@ def assert_codegen_run(
print_mem: bool = False,
print_progress: bool = False,
print_code: bool = False,
eval_mem: bool = False,
) -> List[Dict]:
meta_args, concrete_args, sequence = data
if concrete_args is None:
@@ -39,12 +40,11 @@ def assert_codegen_run(
meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
interp.propagate(*meta_tensors)
codegen = AutoChunkCodeGen(
meta_graph,
max_memory=max_memory,
print_mem=print_est_mem,
print_progress=print_progress,
)
codegen = AutoChunkCodeGen(meta_graph,
max_memory=max_memory,
print_mem=print_est_mem,
print_progress=print_progress,
eval_mem=eval_mem)
chunks = codegen.chunk_infos
# trace and recompile
@@ -108,6 +108,7 @@ def run_test(
print_est_mem: bool = False,
print_mem: bool = False,
print_progress: bool = False,
eval_mem: bool = False,
get_chunk_target: Any = None,
) -> None:
model = model(config=config)
@@ -122,15 +123,14 @@ def run_test(
)
# build model and input
chunks = assert_codegen_run(
model,
data=data,
max_memory=max_memory,
print_code=print_code,
print_est_mem=print_est_mem,
print_mem=print_mem,
print_progress=print_progress,
)
chunks = assert_codegen_run(model,
data=data,
max_memory=max_memory,
print_code=print_code,
print_est_mem=print_est_mem,
print_mem=print_mem,
print_progress=print_progress,
eval_mem=eval_mem)
if get_chunk_target is not None:
chunk_found = [i["region"] for i in chunks]