mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 12:43:55 +00:00
[fx] metainfo_trace as an API. (#1873)
* [fx] metainfo_trace as an API. * [fx] add return.
This commit is contained in:
parent
6d559ea614
commit
448248b27c
@ -1,4 +1,4 @@
|
|||||||
from ._compatibility import compatibility, is_compatible_with_meta
|
from ._compatibility import compatibility, is_compatible_with_meta
|
||||||
from .graph_module import ColoGraphModule
|
from .graph_module import ColoGraphModule
|
||||||
from .passes import MetaInfoProp
|
from .passes import MetaInfoProp, metainfo_trace
|
||||||
from .tracer import ColoTracer, meta_trace, symbolic_trace
|
from .tracer import ColoTracer, meta_trace, symbolic_trace
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
|
from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
|
||||||
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
|
|
||||||
from .meta_info_prop import MetaInfoProp
|
|
||||||
from .concrete_info_prop import ConcreteInfoProp
|
from .concrete_info_prop import ConcreteInfoProp
|
||||||
|
from .meta_info_prop import MetaInfoProp, metainfo_trace
|
||||||
|
from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
|
||||||
|
@ -6,7 +6,7 @@ import torch.fx
|
|||||||
from torch.fx.node import Argument, Node, Target
|
from torch.fx.node import Argument, Node, Target
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
from colossalai.fx._compatibility import compatibility
|
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
|
||||||
from colossalai.fx.profiler import (
|
from colossalai.fx.profiler import (
|
||||||
GraphInfo,
|
GraphInfo,
|
||||||
activation_size,
|
activation_size,
|
||||||
@ -315,3 +315,38 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||||||
]
|
]
|
||||||
|
|
||||||
return tabulate(node_summaries, headers=headers, stralign='right')
|
return tabulate(node_summaries, headers=headers, stralign='right')
|
||||||
|
|
||||||
|
|
||||||
|
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
MetaInfo tracing API
|
||||||
|
|
||||||
|
Given a ``GraphModule`` and a sample input, this API will trace the MetaInfo of a single training cycle,
|
||||||
|
and annotate them on ``gm.graph``.
|
||||||
|
|
||||||
|
Uses:
|
||||||
|
>>> model = ...
|
||||||
|
>>> gm = symbolic_trace(model)
|
||||||
|
>>> args = ... # sample input to the ``GraphModule``
|
||||||
|
>>> metainfo_trace(gm, *args)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo.
|
||||||
|
verbose (bool, optional): Whether to show ``MetaInfoProp.summary()`. Defaults to False.
|
||||||
|
unit (str, optional): The unit of memory. Defaults to "MB".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
|
||||||
|
"""
|
||||||
|
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
|
interp = MetaInfoProp(gm.to(device))
|
||||||
|
if is_compatible_with_meta():
|
||||||
|
from colossalai.fx.profiler import MetaTensor
|
||||||
|
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
|
||||||
|
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
|
||||||
|
interp.propagate(*args, **kwargs)
|
||||||
|
if verbose:
|
||||||
|
interp.summary(unit)
|
||||||
|
gm.to('cpu')
|
||||||
|
del interp
|
||||||
|
return gm
|
||||||
|
Loading…
Reference in New Issue
Block a user