[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)

* [fx] move meta registration

* [fx] fix tests.

* [fx] fix test.

* [fx] fix.

* [meta] refactor _meta_registration.py.

* [fx] add compatibility descriptions.

* [fx] polish import.

* [fx] add a decorator.

* [fx] fix tests.

* [fx] remove print.

* [fx] edit raise error.

* [fx] edit raise error.

* [fx] add type hint.

* [fx] fix import in experimental.

* [rpc] remove color debug.

* [meta] fix naming.

Author:    Super Daniel
Date:      2022-10-18 10:44:23 +08:00
Committed: GitHub
Parent:    e8d8eda5e7
Commit:    393f594051

32 changed files with 351 additions and 310 deletions
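The diff below pulls `compatibility` in from `..._compatibility` and stamps the profiler entry points with `@compatibility(is_backward_compatible=True)`, following the convention of `torch.fx._compatibility`. As a minimal sketch of what such a decorator typically does (the real ColossalAI implementation may differ in its details), it tags the decorated object and records it for later interface checks:

```python
# Minimal sketch of a torch.fx-style compatibility decorator; the actual
# decorator imported from ..._compatibility may carry extra behaviour.
from typing import Callable, TypeVar

T = TypeVar('T')

_BACK_COMPAT_OBJECTS = set()
_MARKED_WITH_COMPATIBILITY = set()


def compatibility(is_backward_compatible: bool = False) -> Callable[[T], T]:
    def decorator(obj: T) -> T:
        # Tag the function/class so tooling can check which public APIs
        # promise a stable interface across releases.
        if is_backward_compatible:
            _BACK_COMPAT_OBJECTS.add(obj.__qualname__)
        _MARKED_WITH_COMPATIBILITY.add(obj.__qualname__)
        obj._compatibility = is_backward_compatible
        return obj
    return decorator
```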


@@ -1,15 +1,19 @@
 from dataclasses import dataclass
-from typing import Callable, Any, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple
 import torch
 from torch.fx.node import Argument, Target
-from . import meta_profiler_function, meta_profiler_module
+from ..._compatibility import compatibility
 from ..memory import activation_size
-from ..constant import INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
+from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
+from .registry import meta_profiler_function, meta_profiler_module
 __all__ = ['profile_function', 'profile_module', 'profile_method']
+# this is for compatibility use
+@compatibility(is_backward_compatible=True)
 @dataclass
 class GraphInfo:
     """
@@ -69,6 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
     """
+@compatibility(is_backward_compatible=True)
 def profile_function(target: 'Target') -> Callable:
     """
     Wrap a `call_function` node or `torch.nn.functional` in order to
@@ -106,6 +111,7 @@ def profile_function(target: 'Target') -> Callable:
     return f
+@compatibility(is_backward_compatible=True)
 def profile_method(target: 'Target') -> Callable:
     """
     Wrap a `call_method` node
@@ -133,6 +139,7 @@ def profile_method(target: 'Target') -> Callable:
     return f
+@compatibility(is_backward_compatible=True)
 def profile_module(module: torch.nn.Module) -> Callable:
     """
     Wrap a `call_module` node or `torch.nn` in order to
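
Taken together, the three wrappers cover the FX node kinds that can execute code (`call_function`, `call_method`, `call_module`). A minimal sketch, assuming only the signatures shown in the diff, of how a caller might dispatch on `node.op` to pick the right wrapper (this is not the actual profiler/interpreter code):

```python
# Sketch only: dispatch an FX node to the matching profiler wrapper.
# Assumes profile_function / profile_method / profile_module are importable
# from the profiler module shown in the diff above.
import torch
from torch.fx import Node


def wrap_node_for_profiling(node: Node, root: torch.nn.Module):
    if node.op == 'call_function':
        return profile_function(node.target)
    if node.op == 'call_method':
        return profile_method(node.target)
    if node.op == 'call_module':
        # node.target is the dotted path of the submodule, e.g. "layer1.0.conv"
        submodule = root.get_submodule(node.target)
        return profile_module(submodule)
    raise ValueError(f'node op {node.op!r} has no profiling wrapper')
```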