mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
[fx] refactor code for profiler / enable fake tensor movement. (#1646)
* [fx/profiling] provide summary for MetaInfoProp. * [fx/profiler] provide a table of summary. * [fx/profiler] provide a table of summary. * [fx/profiler] provide a table of summary. * [fx/profiler] provide a table of summary. * [fx] optimize table repr. * [fx] optimize table repr. * [fx] refactor code for profiler. * [fx] add docstring. * [fx] add docstring. * [fx] skip test. * [fx] redo. * [fx] redo. * [fx] fix import error for torch11. * [fx] fix import error for torch11.
This commit is contained in:
@@ -10,6 +10,7 @@ from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Los
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
|
||||
from colossalai import META_COMPATIBILITY
|
||||
|
||||
INF = float("inf")
|
||||
|
||||
@@ -507,6 +508,9 @@ def solver_pofo(gm: ColoGraphModule,
|
||||
mem_limit -= parameter_size(gm)
|
||||
|
||||
# prepare data
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data)
|
||||
chain: Chain = _construct_chain(node_list, data)
|
||||
chain = _normalize_flops(chain, flops)
|
||||
|
@@ -2,12 +2,12 @@ from typing import List, Tuple
|
||||
from torch.fx import Node
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.profiler import activation_size, parameter_size
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
import math
|
||||
from .linearize import linearize
|
||||
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai import META_COMPATIBILITY
|
||||
|
||||
|
||||
# this is the python compute table code from rotor
|
||||
@@ -340,7 +340,9 @@ def solver_rotor(gm: ColoGraphModule,
|
||||
|
||||
node_list = linearize(gm, cnode)
|
||||
mem_unit = mem_limit * (1.0 - eps) // mem_slots
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
if META_COMPATIBILITY:
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data)
|
||||
|
||||
chain: Chain = _construct_chain(node_list, data)
|
||||
|
Reference in New Issue
Block a user