[fx] refactor code for profiler / enable fake tensor movement. (#1646)

* [fx/profiling] provide summary for MetaInfoProp.

* [fx/profiler] provide a table of summary.

* [fx/profiler] provide a table of summary.

* [fx/profiler] provide a table of summary.

* [fx/profiler] provide a table of summary.

* [fx] optimize table repr.

* [fx] optimize table repr.

* [fx] refactor code for profiler.

* [fx] add docstring.

* [fx] add docstring.

* [fx] skip test.

* [fx] redo.

* [fx] redo.

* [fx] fix import error for torch 1.11.

* [fx] fix import error for torch 1.11.
This commit is contained in:
Super Daniel
2022-09-27 10:26:52 +08:00
committed by GitHub
parent 5d0fdb9cb4
commit 6135e178b3
5 changed files with 103 additions and 67 deletions

View File

@@ -10,6 +10,7 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
from colossalai import META_COMPATIBILITY
INF = float("inf")
@@ -507,6 +508,9 @@ def solver_pofo(gm: ColoGraphModule,
mem_limit -= parameter_size(gm)
# prepare data
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_list, data)
chain = _normalize_flops(chain, flops)

View File

@@ -2,12 +2,12 @@ from typing import List, Tuple
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size
from colossalai.fx.profiler.tensor import MetaTensor
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai import META_COMPATIBILITY
# this is the python compute table code from rotor
@@ -340,7 +340,9 @@ def solver_rotor(gm: ColoGraphModule,
node_list = linearize(gm, cnode)
mem_unit = mem_limit * (1.0 - eps) // mem_slots
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)
chain: Chain = _construct_chain(node_list, data)