mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)
* [fx] move meta registration * [fx] fix tests. * [fx] fix test. * [fx] fix. * [meta] refactor meta registration.py. * [fx] add compatibility descriptions. * [fx] polish import. * [fx] add a decorator. * [fx] fix tests. * [fx] remove print. * [fx] edit raise error. * [fx] edit raise error. * [fx] add type hint. * [fx] fix import in experimental. * [rpc] remove color debug. * [meta] fix naming.
This commit is contained in:
@@ -1,16 +1,20 @@
|
||||
from typing import List, Tuple
|
||||
import copy
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.profiler import parameter_size
|
||||
import math
|
||||
from .linearize import linearize
|
||||
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function, Offload, Prefetch
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from colossalai.fx import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import \
|
||||
_find_nested_ckpt_regions
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
|
||||
from colossalai import META_COMPATIBILITY
|
||||
from colossalai.fx.profiler import parameter_size
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from .linearize import linearize
|
||||
from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
|
||||
Sequence)
|
||||
|
||||
INF = float("inf")
|
||||
|
||||
@@ -508,7 +512,7 @@ def solver_pofo(gm: ColoGraphModule,
|
||||
mem_limit -= parameter_size(gm)
|
||||
|
||||
# prepare data
|
||||
if META_COMPATIBILITY:
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
data = MetaTensor(data, fake_device=next(gm.parameters()).device)
|
||||
MetaInfoProp(gm).run(data)
|
||||
|
Reference in New Issue
Block a user