finish basic inference memory estimation

This commit is contained in:
oahzxl
2022-11-08 10:34:14 +08:00
parent d95cfe2622
commit 12301dd2e9
2 changed files with 23 additions and 2 deletions

View File

@@ -32,9 +32,19 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
# now_mem = torch.cuda.memory_allocated() / 1024**2
# max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2))
# with torch.no_grad():
# fx_out = gm(node, pair)
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem))
# test forward
non_fx_out = model(node.clone(), pair.clone())
fx_out = gm(node.clone(), pair.clone())
with torch.no_grad():
non_fx_out = model(node, pair)
fx_out = gm(node, pair)
assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"