mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)
* [fx] move meta registration * [fx] fix tests. * [fx] fix test. * [fx] fix. * [meta] refactor meta registration.py. * [fx] add compatibility descriptions. * [fx] polish import. * [fx] add a decorator. * [fx] fix tests. * [fx] remove print. * [fx] edit raise error. * [fx] edit raise error. * [fx] add type hint. * [fx] fix import in experimental. * [rpc] remove color debug. * [meta] fix naming.
This commit is contained in:
@@ -1,15 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Any, Dict, Tuple
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx.node import Argument, Target
|
||||
from . import meta_profiler_function, meta_profiler_module
|
||||
|
||||
from ..._compatibility import compatibility
|
||||
from ..memory import activation_size
|
||||
from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
|
||||
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
|
||||
from .registry import meta_profiler_function, meta_profiler_module
|
||||
|
||||
__all__ = ['profile_function', 'profile_module', 'profile_method']
|
||||
|
||||
|
||||
# this is for compatibility use
|
||||
@compatibility(is_backward_compatible=True)
|
||||
@dataclass
|
||||
class GraphInfo:
|
||||
"""
|
||||
@@ -69,6 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
|
||||
"""
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_function(target: 'Target') -> Callable:
|
||||
"""
|
||||
Wrap a `call_function` node or `torch.nn.functional` in order to
|
||||
@@ -106,6 +111,7 @@ def profile_function(target: 'Target') -> Callable:
|
||||
return f
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_method(target: 'Target') -> Callable:
|
||||
"""
|
||||
Wrap a `call_method` node
|
||||
@@ -133,6 +139,7 @@ def profile_method(target: 'Target') -> Callable:
|
||||
return f
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_module(module: torch.nn.Module) -> Callable:
|
||||
"""
|
||||
Wrap a `call_module` node or `torch.nn` in order to
|
||||
|
Reference in New Issue
Block a user