[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)

* [fx] move meta registration

* [fx] fix tests.

* [fx] fix test.

* [fx] fix.

* [meta] refactor meta registration.py.

* [fx] add compatibility descriptions.

* [fx] polish import.

* [fx] add a decorator.

* [fx] fix tests.

* [fx] remove print.

* [fx] edit raise error.

* [fx] edit raise error.

* [fx] add type hint.

* [fx] fix import in experimental.

* [rpc] remove color debug.

* [meta] fix naming.
This commit is contained in:
Super Daniel
2022-10-18 10:44:23 +08:00
committed by GitHub
parent e8d8eda5e7
commit 393f594051
32 changed files with 351 additions and 310 deletions

View File

@@ -1,16 +1,20 @@
from typing import List, Tuple
import copy
import torch
from torch.fx import GraphModule, Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import parameter_size
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch
from typing import List, Tuple
import torch
from colossalai.fx import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import \
_find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
from colossalai import META_COMPATIBILITY
from colossalai.fx.profiler import parameter_size
from torch.fx import GraphModule, Node
from .linearize import linearize
from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
Sequence)
INF = float("inf")
@@ -508,7 +512,7 @@ def solver_pofo(gm: ColoGraphModule,
mem_limit -= parameter_size(gm)
# prepare data
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data)