mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 06:52:46 +00:00
* [hotfix] meta tensor default device. * [siu] add experimental submodules to main branch. * [siu] * [siu] * [analyzer] init. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [test] add test. * Update symbolic_trace.py * mark skip tests. * try except. * try except. * try except. * s * init * init * fix * skip * skip --------- Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com> Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
41 lines
983 B
Python
41 lines
983 B
Python
import torch
|
|
import torch.fx
|
|
from torch.fx import GraphModule
|
|
|
|
from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
|
|
from .passes.graph_profile import FlopProfiler
|
|
|
|
|
|
def register_flop_count_impl(func):
    """Build a decorator that registers a custom FLOP-count implementation.

    Applying the returned decorator to ``impl`` stores it in
    ``FlopProfiler._custom_flop_count_impl`` under the key ``func``. It then
    hands ``impl`` back unchanged, so the decorated name still refers to the
    original function.

    Args:
        func: Registry key for the implementation (presumably the op/callable
            whose FLOPs ``impl`` counts — confirm against FlopProfiler).

    Returns:
        A decorator that registers its argument and returns it as-is.
    """

    def _decorator(impl):
        # Resolve the registry when the decorator is applied, not when the
        # factory is called, so the entry goes into the current table.
        registry = FlopProfiler._custom_flop_count_impl
        registry[func] = impl
        return impl

    return _decorator
|
|
|
|
|
|
def register_shape_impl(func):
    """Build a decorator that registers a custom shape-propagation handler.

    Applying the returned decorator to ``impl`` stores it in
    ``ShapeProp._custom_dispatch_func`` under the key ``func``. It then hands
    ``impl`` back unchanged, so the decorated name still refers to the
    original function.

    Args:
        func: Registry key for the handler (presumably the op/callable whose
            output shape ``impl`` computes — confirm against ShapeProp).

    Returns:
        A decorator that registers its argument and returns it as-is.
    """

    def _decorator(impl):
        # Resolve the dispatch table when the decorator is applied, not when
        # the factory is called, so the entry goes into the current table.
        dispatch_table = ShapeProp._custom_dispatch_func
        dispatch_table[func] = impl
        return impl

    return _decorator
|
|
|
|
|
|
def symbolic_profile(module: GraphModule, *args, verbose=False) -> GraphModule:
    """Run shape propagation and then graph profiling on ``module``.

    Args:
        module (GraphModule): The module to be profiled.
        *args: Sample inputs, forwarded unchanged to both passes.
        verbose (bool): Whether to print the profiling result.

    Returns:
        GraphModule: The module returned by the profiling pass.
    """
    # Shape propagation must run first; profiling operates on its output.
    shaped_module = shape_prop_pass(module, *args)
    profiled_module = graph_profile_pass(shaped_module, *args, verbose=verbose)
    return profiled_module
|