Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-12 12:47:21 +00:00
[fx] refactor memory utils and extend shard utils. (#1754)
* [fx] change memory.py to memory_utils.py.
* [fx] add shard utils.
* [fx] fix import.
* [fx] check code style.
* [fx] add comment.
* [autoparallel] first move.
* [fx] add time computations.
@@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple
 
 import torch
 import torch.fx
-from colossalai.fx._compatibility import compatibility
-from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
 from torch.fx.node import Argument, Node, Target
 from torch.utils._pytree import tree_flatten
 
+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module
+
 
 @compatibility(is_backward_compatible=True)
 class ConcreteInfoProp(torch.fx.Interpreter):
@@ -22,17 +23,17 @@ class ConcreteInfoProp(torch.fx.Interpreter):
         DIM_HIDDEN = 16
         DIM_OUT = 16
         model = torch.nn.Sequential(
             torch.nn.Linear(DIM_IN, DIM_HIDDEN),
             torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
         ).cuda()
         input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
         gm = symbolic_trace(model)
         interp = ConcreteInfoProp(gm)
         interp.run(input_sample)
         print(interp.summary(unit='kb'))
 
 
         output of above code is
         Op type      Op       Forward time    Backward time    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
         -----------  -------  --------------  ---------------  -------------  ---------  ---------  ---------  ---------
         placeholder  input_1  0.0 s           0.0 s            False          0.00 KB    0.00 KB    0.00 KB    0.00 KB
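The docstring example above starts mid-snippet: the hunk begins at DIM_HIDDEN, so BATCH_SIZE and DIM_IN are defined above the visible lines. A minimal self-contained version is sketched below; the values chosen for those two constants and the import path for ConcreteInfoProp are assumptions (the diff does not show the file that defines the class), and plain torch.fx.symbolic_trace is used since it suffices for a simple Sequential model.

import torch
from torch.fx import symbolic_trace

# Hypothetical import path; the diff does not show which module defines
# ConcreteInfoProp, so adjust this to match your checkout.
from colossalai.fx.passes.concrete_info_prop import ConcreteInfoProp

# BATCH_SIZE and DIM_IN are defined above the visible hunk; these values
# are assumptions for illustration only.
BATCH_SIZE = 2
DIM_IN = 4
DIM_HIDDEN = 16
DIM_OUT = 16

# A CUDA device is required, as in the original example: ConcreteInfoProp
# runs the graph with concrete tensors to record real time and memory.
model = torch.nn.Sequential(
    torch.nn.Linear(DIM_IN, DIM_HIDDEN),
    torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
).cuda()
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")

gm = symbolic_trace(model)        # trace into a torch.fx.GraphModule
interp = ConcreteInfoProp(gm)     # interpreter that profiles node by node
interp.run(input_sample)          # records forward/backward time and memory
print(interp.summary(unit='kb'))  # per-node table like the one shown above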
@@ -229,8 +230,8 @@ class ConcreteInfoProp(torch.fx.Interpreter):
 
     def summary(self, unit: str = 'MB') -> str:
         """
         Summarizes the memory and FLOPs statistics of the `GraphModule` in
         tabular format. Note that this API requires the ``tabulate`` module
         to be installed.
         """
         # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
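As the docstring notes, summary depends on the tabulate package. A minimal sketch of how the columns visible in the example output map onto a tabulate call is given below; the row values are invented for illustration, whereas the real method fills them from each node's profiling results.

from tabulate import tabulate

headers = ['Op type', 'Op', 'Forward time', 'Backward time',
           'SAVE_FWD_IN', 'FWD_OUT', 'FWD_TMP', 'BWD_OUT', 'BWD_TMP']
# One row per graph node; values here are placeholders for illustration.
rows = [
    ['placeholder', 'input_1', '0.0 s', '0.0 s', False,
     '0.00 KB', '0.00 KB', '0.00 KB', '0.00 KB'],
]
print(tabulate(rows, headers=headers))

The unit='kb' argument in the earlier example then only rescales the byte counts before they are rendered into cells like the "0.00 KB" entries shown.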