From 12301dd2e9a1889fe76c6ab719aff1404e92aea0 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 8 Nov 2022 10:34:14 +0800
Subject: [PATCH] finish basic inference memory estimation

---
 chunk_codegen.py     | 11 +++++++++++
 chunk_codegen_run.py | 14 ++++++++++++--
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/chunk_codegen.py b/chunk_codegen.py
index 4ca33a4d5..01b29cb33 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -64,6 +64,8 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
         # if node is placeholder, just add the size of the node
         if node.op == 'placeholder':
             act_memory += _get_meta_node_size(node)
+            act_memory_peak_log.append(act_memory)
+            act_memory_after_node_log.append(act_memory)
         # skip output
         elif node.op == 'output':
             continue
@@ -81,6 +83,15 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
             act_memory_after_node_log.append(act_memory)
 
     act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
+    act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
+
+    # for i in act_memory_peak_log:
+    #     print("%.2f " % i, end='')
+    # print("\n")
+    # for i in act_memory_after_node_log:
+    #     print("%.2f " % i, end='')
+    # print("\n")
+
     param_memory = parameter_size(gm)
     return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
 
diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py
index 1ab7d958b..cc975f2ea 100644
--- a/chunk_codegen_run.py
+++ b/chunk_codegen_run.py
@@ -32,9 +32,19 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
 
 
 def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
+    # now_mem = torch.cuda.memory_allocated() / 1024**2
+    # max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    # print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2))
+    # with torch.no_grad():
+    #     fx_out = gm(node, pair)
+    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    # print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem))
+
     # test forward
-    non_fx_out = model(node.clone(), pair.clone())
-    fx_out = gm(node.clone(), pair.clone())
+    with torch.no_grad():
+        non_fx_out = model(node, pair)
+        fx_out = gm(node, pair)
     assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
     assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"
 
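
Note on the estimator in chunk_codegen.py: the patch walks the FX graph once, accumulating each node's output size into a running activation total and logging the peak and after-node totals, now both converted to MB. The helpers _get_meta_node_size and parameter_size are not part of this diff, so the sketch below substitutes assumed implementations that read ShapeProp's node.meta['tensor_meta'] and sum over gm.parameters(); it is a minimal illustration of the approach under those assumptions, not the repository's actual code.

import torch
import torch.fx


def _tensor_meta_size(node: torch.fx.Node) -> int:
    # Assumed stand-in for _get_meta_node_size: turn the shape/dtype metadata
    # recorded by torch.fx.passes.shape_prop.ShapeProp into a byte count.
    meta = node.meta.get('tensor_meta', None)
    if meta is None:
        return 0
    numel = 1
    for dim in meta.shape:
        numel *= dim
    return numel * torch.tensor([], dtype=meta.dtype).element_size()


def estimate_inference_mem(gm: torch.fx.GraphModule):
    # Rough inference-time memory estimate: returns
    # (activation + parameter memory, parameter memory) in MB.
    act_memory = 0
    act_memory_peak_log, act_memory_after_node_log = [], []
    for node in gm.graph.nodes:
        if node.op == 'output':
            continue
        # Placeholders and compute nodes both contribute their output size;
        # a fuller estimator would also release activations after their last use.
        act_memory += _tensor_meta_size(node)
        act_memory_peak_log.append(act_memory)
        act_memory_after_node_log.append(act_memory)
    # Assumed stand-in for parameter_size: total bytes of all parameters.
    param_memory = sum(p.numel() * p.element_size() for p in gm.parameters())
    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)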
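
Note on the test change in chunk_codegen_run.py: wrapping both forward passes in torch.no_grad() keeps autograd from retaining intermediate buffers, which matches the inference assumption behind the estimator and is also why the inputs no longer need .clone() to stay independent of a backward pass. The commented-out block measures real CUDA allocator traffic for comparison against the estimate; below is a self-contained sketch of that measurement pattern (measure_fwd_mem and its arguments are illustrative names, not identifiers from the test).

import torch


def measure_fwd_mem(model: torch.nn.Module, *inputs):
    # Report the current and peak CUDA memory (MB) consumed by one
    # no-grad forward pass, mirroring the debug prints in the patch.
    torch.cuda.reset_peak_memory_stats()
    now_mem = torch.cuda.memory_allocated() / 1024 ** 2
    max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2
    with torch.no_grad():
        out = model(*inputs)
    new_now_mem = torch.cuda.memory_allocated() / 1024 ** 2
    new_max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2
    print("now:%.2f max:%.2f" % (new_now_mem - now_mem, new_max_mem - max_mem))
    return out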