[fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.

* [fx] remove import.

* [fx] remove import.

* [fx] remove import.

* [fx] tune the meta calculations.

* [fx] polish comments.

* [fx] remove assertions.

* [fx] modify test cases.

* [fx] modify test cases.

* [fx] optimize import.

* [fx
This commit is contained in:
Super Daniel
2022-09-23 10:59:47 +08:00
committed by GitHub
parent f7f2248771
commit d967779a32
16 changed files with 413 additions and 207 deletions

View File

@@ -94,11 +94,9 @@ class MetaInfoProp(torch.fx.Interpreter):
tensor_meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0} # extend MetaInfo to `n.meta`
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
for par in n.all_input_nodes:
par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
n.meta['type'] = type(result)
# retain the autograd graph
@@ -224,7 +222,7 @@ class MetaInfoProp(torch.fx.Interpreter):
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
return args[0], GraphInfo(fwd_mem_in=activation_size(args[0]))
return args[0], GraphInfo(save_fwd_in=True)
def propagate(self, *args):
"""