ColossalAI/colossalai/_analyzer/fx/symbolic_profile.py
Super Daniel fff98f06ed
[analyzer] a minimal implementation of static graph analyzer (#2852)
* [hotfix] meta tensor default device.

* [siu] add experimental submodules to main branch.

* [siu]

* [siu]

* [analyzer] init.

* [analyzer] readme.

* [analyzer] readme.

* [analyzer] readme.

* [analyzer] readme.

* [test] add test.

* Update symbolic_trace.py

* mark skip tests.

* try except.

* try except.

* try except.

* s

* init

* init

* fix

* skip

* skip

---------

Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com>
Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
2023-03-10 13:21:05 +08:00


import torch
import torch.fx
from torch.fx import GraphModule

from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
from .passes.graph_profile import FlopProfiler


def register_flop_count_impl(func):
    """Decorator factory: register the decorated function as FlopProfiler's FLOP counter for `func`."""

    def wrapper(impl):
        FlopProfiler._custom_flop_count_impl[func] = impl
        return impl

    return wrapper


def register_shape_impl(func):
    """Decorator factory: register the decorated function as ShapeProp's custom dispatch for `func`."""

    def wrapper(impl):
        ShapeProp._custom_dispatch_func[func] = impl
        return impl

    return wrapper


def symbolic_profile(module: GraphModule, *args, verbose=False) -> GraphModule:
    """Symbolically profile a model with sample inputs.

    Runs shape propagation first, then the graph profiling pass.

    Args:
        module (GraphModule): The module to be profiled.
        *args: The sample inputs.
        verbose (bool): Whether to print the profiling result. Defaults to False.

    Returns:
        GraphModule: The profiled module.
    """
    module = shape_prop_pass(module, *args)
    module = graph_profile_pass(module, *args, verbose=verbose)
    return module
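`register_flop_count_impl` and `register_shape_impl` are decorator factories: each stores the decorated function in a class-level dict keyed by the target op and hands the function back unchanged. The sketch below reproduces that registry pattern with the standard library only, so it runs without ColossalAI or PyTorch. `ToyFlopProfiler`, its `count` method, `my_matmul`, and the `(forward, backward)` return convention are hypothetical; only the `_custom_flop_count_impl` dict and the decorator body mirror the source, and how the real `FlopProfiler` calls a registered impl is an assumption here. The example counts a matrix multiply of shapes `(m, k)` and `(k, n)` as `2 * m * n * k` forward FLOPs (one multiply and one add per term) and twice that for the backward pass.

```python
from typing import Any, Callable, Dict, Tuple


class ToyFlopProfiler:
    """Hypothetical stand-in for FlopProfiler; only the registry mirrors the source."""

    # Class-level registry shared by all instances: op -> FLOP-count function.
    _custom_flop_count_impl: Dict[Callable, Callable] = {}

    def count(self, func: Callable, *args: Any) -> Tuple[int, int]:
        # Assumed lookup: use a registered impl if present, otherwise count 0 FLOPs.
        impl = self._custom_flop_count_impl.get(func)
        if impl is None:
            return 0, 0
        return impl(*args)


def register_flop_count_impl(func: Callable) -> Callable[[Callable], Callable]:
    # Same structure as the original decorator: record impl, return it unchanged.
    def wrapper(impl: Callable) -> Callable:
        ToyFlopProfiler._custom_flop_count_impl[func] = impl
        return impl

    return wrapper


def my_matmul(a, b):
    """Hypothetical op whose FLOPs we want to count."""
    raise NotImplementedError


@register_flop_count_impl(my_matmul)
def my_matmul_flops(a_shape: Tuple[int, int], b_shape: Tuple[int, int]) -> Tuple[int, int]:
    # (m, k) @ (k, n): m * n * k multiply-adds, counted as 2 FLOPs each.
    m, k = a_shape
    k2, n = b_shape
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    fwd = 2 * m * n * k
    # Backward computes grad_a = grad_out @ b.T and grad_b = a.T @ grad_out,
    # two matmuls of the same size as the forward one.
    bwd = 2 * fwd
    return fwd, bwd


profiler = ToyFlopProfiler()
print(profiler.count(my_matmul, (4, 8), (8, 16)))  # (1024, 2048)
```

`register_shape_impl` follows the same pattern, writing into `ShapeProp._custom_dispatch_func` instead. Because `wrapper` returns `impl` unchanged, the decorated function stays directly callable: `my_matmul_flops((2, 3), (3, 5))` still returns `(60, 120)`.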