[fx] refactor memory utils and extend shard utils. (#1754)

* [fx] change memory.py to memory_utils.py.

* [fx] add shard utils.

* [fx] fix import.

* [fx] check code style.

* [fx] add comment.

* [autoparallel] first move.

* [fx] add time computations.
This commit is contained in:
Super Daniel
2022-10-26 14:24:41 +08:00
committed by GitHub
parent 63f250bbd4
commit 0584654c79
14 changed files with 177 additions and 122 deletions

View File

@@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (
GraphInfo,
activation_size,
calculate_fwd_in,
calculate_fwd_out,
calculate_fwd_tmp,
profile_function,
profile_method,
profile_module,
)
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
@@ -52,7 +61,7 @@ class MetaInfoProp(torch.fx.Interpreter):
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_IN, DIM_HIDDEN),
torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
)
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
@@ -60,9 +69,9 @@ class MetaInfoProp(torch.fx.Interpreter):
interp = MetaInfoProp(gm)
interp.run(input_sample)
print(interp.summary(format='kb')) # don't panic if some statistics are 0.00 MB
# output of above code is
# output of above code is
Op type Op Forward FLOPs Backward FLOPs FWD_OUT FWD_TMP BWD_OUT BWD_TMP
----------- ------- --------------- ---------------- --------- --------- --------- ---------
placeholder input_1 0 FLOPs 0 FLOPs 0.00 KB 0.00 KB 0.00 KB 0.00 KB
@@ -248,8 +257,8 @@ class MetaInfoProp(torch.fx.Interpreter):
def summary(self, unit: str = 'MB') -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
to be installed.
"""
# https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py