[fx/profiler] tuned the calculation of memory estimation (#1619)

* [fx] tuned the meta info and rotor solver.

* [fx] remove import.

* [fx] remove import.

* [fx] remove import.

* [fx] tune the meta calculations.

* [fx] polish comments.

* [fx] remove assertions.

* [fx] modify test cases.

* [fx] modify test cases.

* [fx] optimize import.

* [fx
This commit is contained in:
Super Daniel
2022-09-23 10:59:47 +08:00
committed by GitHub
parent f7f2248771
commit d967779a32
16 changed files with 413 additions and 207 deletions

View File

@@ -2,6 +2,7 @@ from typing import List, Tuple
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size
from colossalai.fx.profiler.tensor import MetaTensor
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
@@ -123,7 +124,9 @@ def _fwd_xbar(node: List[Node]) -> int:
xbar = 0
for n in node:
xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
xbar += n.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
xbar += n.meta['fwd_mem_out']
return xbar
@@ -177,10 +180,13 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
def _get_deps_size():
deps_size = 0
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
deps_size -= k.meta['fwd_mem_tmp']
if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
deps_size -= k.meta['fwd_mem_out']
return deps_size
@@ -333,8 +339,8 @@ def solver_rotor(gm: ColoGraphModule,
"""
node_list = linearize(gm, cnode)
mem_limit -= parameter_size(gm)
mem_unit = mem_limit * (1.0 - eps) // mem_slots
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_list, data)