finish memory estimation

This commit is contained in:
oahzxl
2022-11-11 15:43:03 +08:00
parent 22f9c60b6b
commit d7634af5c0
2 changed files with 80 additions and 47 deletions

View File

@@ -32,14 +32,14 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
# now_mem = torch.cuda.memory_allocated() / 1024**2
# max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2))
# with torch.no_grad():
# fx_out = gm(node, pair)
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem))
now_mem = torch.cuda.memory_allocated() / 1024**2
with torch.no_grad():
node0 = node.clone()
pair0 = pair.clone()
node1, pair1 = gm(node0, pair0)
new_now_mem = torch.cuda.memory_allocated() / 1024**2
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
# test forward
with torch.no_grad():
@@ -63,8 +63,8 @@ def _run_offload_codegen(rank):
# build model and input
model = evoformer_base().cuda()
node = torch.randn(1, 16, 32, 256).cuda()
pair = torch.randn(1, 32, 32, 128).cuda()
node = torch.randn(1, 100, 300, 256).cuda()
pair = torch.randn(1, 300, 300, 128).cuda()
# trace the module and replace codegen
graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})