[fx] refactor memory utils and extend shard utils. (#1754)
* [fx] change memory.py to memory_utils.py.
* [fx] add shard utils.
* [fx] fix import.
* [fx] check code style.
* [fx] add comment.
* [autoparallel] first move.
* [fx] add time computations.
@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map
 from .._compatibility import compatibility
 from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
 from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
-from .memory import activation_size, parameter_size
+from .memory_utils import activation_size, parameter_size
 from .opcount import flop_mapping
 from .tensor import MetaTensor
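For context, the renamed module keeps the same helpers: `activation_size` measures the byte footprint of a tensor (or nested structure of tensors) and `parameter_size` the byte footprint of a module's parameters. A minimal sketch of their use, assuming the post-rename import path `colossalai.fx.profiler.memory_utils` and these signatures:

import torch
# Assumed import path after the memory.py -> memory_utils.py rename.
from colossalai.fx.profiler.memory_utils import activation_size, parameter_size

mod = torch.nn.Linear(1024, 1024)
x = torch.rand(64, 1024)

# Bytes of all parameters: (1024 * 1024 + 1024) fp32 values * 4 bytes each.
print(parameter_size(mod))
# Bytes of the output activations: 64 * 1024 fp32 values * 4 bytes each.
print(activation_size(mod(x)))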
@@ -286,13 +286,13 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` functions are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
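The hunk cuts the doctest off at this point. A plausible completion, assuming `profile_function` is re-exported from the profiler package and that the wrapped callable returns the real output together with a `GraphInfo`-style record of memory cost and FLOPs:

import torch
from colossalai.fx.profiler import profile_function  # assumed export path

input = torch.rand(100, 100, 100, 100, device='meta')
func = torch.nn.functional.relu
# Runs relu on meta tensors (no real memory is allocated) and records
# memory cost and FLOPs; the (output, meta_info) return shape is an
# assumption inferred from the surrounding code, not confirmed by the diff.
output, meta_info = profile_function(func)(input)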
@@ -342,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
def profile_method(target: 'Target', device: str = 'meta') -> Callable:
    """
    Wrap a `call_method` node to
    record the memory cost and FLOPs of the execution.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
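`call_method` nodes name a method that is resolved against their first argument, so the wrapper presumably receives the receiver first, mirroring `torch.fx` semantics. A hedged sketch, with the same assumed export path and return shape as above:

import torch
from colossalai.fx.profiler import profile_method  # assumed export path

x = torch.rand(8, 16, device='meta')
# 'view' is the method name; x is the receiver, followed by the method's
# own arguments, exactly as torch.fx packs args for a call_method node.
output, meta_info = profile_method('view')(x, 16, 8)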
@@ -360,13 +360,13 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` modules are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
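As with `profile_function`, the example stops at the hunk boundary. A plausible completion under the same assumptions; note that the module's weights presumably need to live on the meta device as well for the forward pass to stay allocation-free:

import torch
from colossalai.fx.profiler import profile_module  # assumed export path

input = torch.rand(4, 3, 224, 224, device='meta')
mod = torch.nn.Conv2d(3, 128, 3).to('meta')  # move the weights to meta too
# Records the Conv2d forward's memory cost and FLOPs; return shape assumed.
output, meta_info = profile_module(mod)(input)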