Finish basic inference memory estimation

This commit is contained in:
oahzxl
2022-11-08 10:34:14 +08:00
parent d95cfe2622
commit 12301dd2e9
2 changed files with 23 additions and 2 deletions

View File

@@ -64,6 +64,8 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
# if node is placeholder, just add the size of the node
if node.op == 'placeholder':
act_memory += _get_meta_node_size(node)
act_memory_peak_log.append(act_memory)
act_memory_after_node_log.append(act_memory)
# skip output
elif node.op == 'output':
continue
@@ -81,6 +83,15 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
act_memory_after_node_log.append(act_memory)
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
# for i in act_memory_peak_log:
# print("%.2f " % i, end='')
# print("\n")
# for i in act_memory_after_node_log:
# print("%.2f " % i, end='')
# print("\n")
param_memory = parameter_size(gm)
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)