[fx] refactor memory utils and extend shard utils. (#1754)

* [fx] change memory.py to memory_utils.py.

* [fx] add shard utils.

* [fx] fix import.

* [fx] check code style.

* [fx] add comment.

* [autoparallel] first move.

* [fx] add time computations.
Author: Super Daniel
Date: 2022-10-26 14:24:41 +08:00
Committed by: GitHub
Parent: 63f250bbd4
Commit: 0584654c79
14 changed files with 177 additions and 122 deletions


@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
-from .memory import activation_size, parameter_size
+from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor
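
For orientation, a minimal sketch of what the two relocated helpers compute, assuming the usual `numel() * element_size()` byte accounting over a pytree of tensors; the bodies below are illustrative, not the module's actual implementation.

import torch
from torch.utils._pytree import tree_flatten

def activation_size(out) -> int:
    # Sum the byte size of every tensor in an (arbitrarily nested) output.
    tensors, _ = tree_flatten(out)
    return sum(t.numel() * t.element_size() for t in tensors if isinstance(t, torch.Tensor))

def parameter_size(mod: torch.nn.Module) -> int:
    # Total byte size of a module's parameters.
    return sum(p.numel() * p.element_size() for p in mod.parameters())
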
@@ -286,13 +286,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
"""
-Wrap a `call_function` node or `torch.nn.functional` in order to
+Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn.functional` are available.
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
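
As a hedged companion to the doctest above (plain PyTorch, not the profiler's wrapper API): operations on `device='meta'` tensors propagate only shape and dtype and allocate no real memory, which is what lets the wrapper account for activation bytes cheaply.

import torch

inp = torch.rand(100, 100, 100, 100, device='meta')   # no real allocation
out = torch.nn.functional.relu(inp)                    # meta kernel: shapes/dtypes only
# Activation byte count, as activation_size would tally it: 10**8 floats * 4 bytes.
print(out.shape, out.numel() * out.element_size())     # torch.Size([100, 100, 100, 100]) 400000000
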
@@ -342,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
"""
Wrap a `call_method` node
-record the memory cost and FLOPs of the execution.
+record the memory cost and FLOPs of the execution.
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
@@ -360,13 +360,13 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
"""
-Wrap a `call_module` node or `torch.nn` in order to
+Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn` are available.
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
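
A hedged sketch in the same spirit for the module case, assuming the installed PyTorch ships meta kernels for conv2d; it pairs the docstring's `Conv2d` with the byte accounting the profiler performs.

import torch

inp = torch.rand(4, 3, 224, 224, device='meta')
mod = torch.nn.Conv2d(3, 128, 3).to('meta')   # parameters become meta tensors
out = mod(inp)                                # shape inference only, no real compute

# Parameter bytes: (128*3*3*3 weights + 128 biases) * 4 bytes = 14336.
param_bytes = sum(p.numel() * p.element_size() for p in mod.parameters())
# Activation bytes for the (4, 128, 222, 222) output.
act_bytes = out.numel() * out.element_size()
print(param_bytes, act_bytes)
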