mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[fx/profiler] tuned the calculation of memory estimation (#1619)
* [fx] tuned the meta info and rotor solver. * [fx] remove import. * [fx] remove import. * [fx] remove import. * [fx] tune the meta calculations. * [fx] polish comments. * [fx] remove assertions. * [fx] modify test cases. * [fx] modify test cases. * [fx] optimize import. * [fx
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import List, Tuple
|
||||
from torch.fx import Node
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.profiler import activation_size, parameter_size
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
import math
|
||||
from .linearize import linearize
|
||||
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
|
||||
@@ -123,7 +124,9 @@ def _fwd_xbar(node: List[Node]) -> int:
|
||||
|
||||
xbar = 0
|
||||
for n in node:
|
||||
xbar += n.meta['fwd_mem_tmp'] + n.meta['fwd_mem_out']
|
||||
xbar += n.meta['fwd_mem_tmp']
|
||||
if any(map(lambda x: x.meta['save_fwd_in'], n.users)):
|
||||
xbar += n.meta['fwd_mem_out']
|
||||
return xbar
|
||||
|
||||
|
||||
@@ -177,10 +180,13 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
|
||||
def _get_deps_size():
|
||||
deps_size = 0
|
||||
for k, v in deps.items():
|
||||
k: Node
|
||||
if v > 0:
|
||||
deps_size += k.meta['bwd_mem_out']
|
||||
if v == float('-inf'):
|
||||
deps_size -= k.meta['fwd_mem_tmp'] + k.meta['fwd_mem_out']
|
||||
deps_size -= k.meta['fwd_mem_tmp']
|
||||
if any(map(lambda x: x.meta['save_fwd_in'], k.users)):
|
||||
deps_size -= k.meta['fwd_mem_out']
|
||||
|
||||
return deps_size
|
||||
|
||||
@@ -333,8 +339,8 @@ def solver_rotor(gm: ColoGraphModule,
|
||||
"""
|
||||
|
||||
node_list = linearize(gm, cnode)
|
||||
mem_limit -= parameter_size(gm)
|
||||
mem_unit = mem_limit * (1.0 - eps) // mem_slots
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data)
|
||||
|
||||
chain: Chain = _construct_chain(node_list, data)
|
||||
|
Reference in New Issue
Block a user