Mirror of https://github.com/hpcaitech/ColossalAI.git
[fx] refactor memory utils and extend shard utils. (#1754)
* [fx] change memory.py to memory_utils.py.
* [fx] add shard utils.
* [fx] fix import.
* [fx] check code style.
* [fx] add comment.
* [autoparallel] first move.
* [fx] add time computations.
@@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
-from colossalai.fx._compatibility import compatibility
-from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
-                                    profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map

+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import (
+    GraphInfo,
+    activation_size,
+    calculate_fwd_in,
+    calculate_fwd_out,
+    calculate_fwd_tmp,
+    profile_function,
+    profile_method,
+    profile_module,
+)


@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
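The names imported from colossalai.fx.profiler above are the byte- and FLOP-counting primitives this pass builds on. As a small illustration of the size side, here is a sketch exercising activation_size; it assumes activation_size accepts a tensor (or a pytree of tensors) and returns a byte count, with a manual numel() * element_size() line as a cross-check.

# Sketch only: assumes activation_size(x) returns the total byte size of
# the activations in x, where x is a tensor or a pytree of tensors.
import torch

from colossalai.fx.profiler import activation_size

x = torch.rand(2, 4)                             # 8 float32 elements
print(activation_size(x))                        # expected 8 * 4 = 32 bytes
print(x.numel() * x.element_size())              # manual cross-check: 32

# Container outputs are why a pytree-aware helper matters when profiling
# arbitrary fx nodes whose results may be tuples or dicts of tensors.
print(activation_size((x, torch.rand(16, 16))))  # expected 32 + 1024 bytes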
@@ -52,7 +61,7 @@ class MetaInfoProp(torch.fx.Interpreter):
        DIM_HIDDEN = 16
        DIM_OUT = 16
        model = torch.nn.Sequential(
            torch.nn.Linear(DIM_IN, DIM_HIDDEN),
            torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
        )
        input_sample = torch.rand(BATCH_SIZE, DIM_IN)
@@ -60,9 +69,9 @@ class MetaInfoProp(torch.fx.Interpreter):
        interp = MetaInfoProp(gm)
        interp.run(input_sample)
        print(interp.summary(format='kb'))    # don't panic if some statistics are 0.00 MB

        # output of above code is
        Op type      Op       Forward FLOPs    Backward FLOPs    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
        -----------  -------  ---------------  ----------------  ---------  ---------  ---------  ---------
        placeholder  input_1  0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
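The docstring example above starts from an already traced gm. A self-contained version of the same flow might look like the sketch below; the import path for MetaInfoProp, the use of torch.fx.symbolic_trace rather than ColossalAI's own tracer, and the unit='kb' keyword (taken from the summary signature in the next hunk) are assumptions, not part of the diff.

# Sketch of the docstring's usage example, made self-contained.
# Assumptions: MetaInfoProp is importable from colossalai.fx.passes.meta_info_prop,
# torch.fx.symbolic_trace is enough for this toy model, and summary()
# takes the `unit` keyword shown in its signature.
import torch
import torch.fx

from colossalai.fx.passes.meta_info_prop import MetaInfoProp

BATCH_SIZE, DIM_IN, DIM_HIDDEN, DIM_OUT = 2, 4, 16, 16

model = torch.nn.Sequential(
    torch.nn.Linear(DIM_IN, DIM_HIDDEN),
    torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
)
input_sample = torch.rand(BATCH_SIZE, DIM_IN)

gm = torch.fx.symbolic_trace(model)   # produce a GraphModule
interp = MetaInfoProp(gm)             # the interpreter defined in this file
interp.run(input_sample)              # annotate nodes with memory/FLOP metadata
print(interp.summary(unit='kb'))      # requires the tabulate package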
@@ -248,8 +257,8 @@ class MetaInfoProp(torch.fx.Interpreter):
    def summary(self, unit: str = 'MB') -> str:
        """
        Summarizes the memory and FLOPs statistics of the `GraphModule` in
        tabular format. Note that this API requires the ``tabulate`` module
        to be installed.
        """
        # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
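For reference, a table like the one in the docstring can be rendered with the tabulate package along these lines; the single row below is the placeholder line copied from the example output, not data read back from node.meta.

# Illustrative only: rendering a per-node summary table with tabulate.
# The row is the placeholder line from the example output above; how
# MetaInfoProp collects the real per-node values is not shown here.
from tabulate import tabulate

headers = ['Op type', 'Op', 'Forward FLOPs', 'Backward FLOPs',
           'FWD_OUT', 'FWD_TMP', 'BWD_OUT', 'BWD_TMP']
rows = [
    ['placeholder', 'input_1', '0 FLOPs', '0 FLOPs',
     '0.00 KB', '0.00 KB', '0.00 KB', '0.00 KB'],
]
print(tabulate(rows, headers=headers))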