mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -9,7 +9,7 @@ from ..memory_utils import activation_size
|
||||
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
|
||||
from .registry import meta_profiler_function, meta_profiler_module
|
||||
|
||||
__all__ = ['profile_function', 'profile_module', 'profile_method']
|
||||
__all__ = ["profile_function", "profile_module", "profile_method"]
|
||||
|
||||
|
||||
# this is for compatibility use
|
||||
@@ -42,6 +42,7 @@ class GraphInfo:
|
||||
bwd_mem_tmp (int): See the above illustration.
|
||||
bwd_mem_out (int): See the above illustration.
|
||||
"""
|
||||
|
||||
fwd_flop: int = 0
|
||||
bwd_flop: int = 0
|
||||
fwd_mem_in: int = 0
|
||||
@@ -50,8 +51,7 @@ class GraphInfo:
|
||||
bwd_mem_out: int = 0
|
||||
|
||||
|
||||
CALL_FUNCTION_MSG = \
|
||||
"""
|
||||
CALL_FUNCTION_MSG = """
|
||||
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
|
||||
from colossalai.fx.profiler.experimental import meta_profiler_function
|
||||
@meta_profiler_function.register(YOUR_FUNCTION)
|
||||
@@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
|
||||
macs = ...
|
||||
return flops, macs
|
||||
"""
|
||||
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
|
||||
CALL_MODULE_MSG = \
|
||||
"""
|
||||
CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}"
|
||||
CALL_MODULE_MSG = """
|
||||
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
|
||||
from colossalai.fx.profiler.experimental import meta_profiler_module
|
||||
@meta_profiler_module.register(YOUR_MODULE)
|
||||
@@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_function(target: 'Target') -> Callable:
|
||||
def profile_function(target: "Target") -> Callable:
|
||||
"""
|
||||
Wrap a `call_function` node or `torch.nn.functional` in order to
|
||||
record the memory cost and FLOPs of the execution.
|
||||
@@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable:
|
||||
|
||||
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||
assert meta_profiler_function.has(target) or meta_profiler_function.has(
|
||||
target.__name__), CALL_FUNCTION_MSG.format(target)
|
||||
target.__name__
|
||||
), CALL_FUNCTION_MSG.format(target)
|
||||
|
||||
fwd_tmp = 0
|
||||
fwd_out = 0
|
||||
out = func(*args, **kwargs)
|
||||
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
|
||||
if target not in INPLACE_OPS and not kwargs.get("inplace", False):
|
||||
fwd_out = activation_size(out)
|
||||
if meta_profiler_function.has(target):
|
||||
profiler = meta_profiler_function.get(target)
|
||||
@@ -112,7 +112,7 @@ def profile_function(target: 'Target') -> Callable:
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def profile_method(target: 'Target') -> Callable:
|
||||
def profile_method(target: "Target") -> Callable:
|
||||
"""
|
||||
Wrap a `call_method` node
|
||||
record the memory cost and FLOPs of the execution.
|
||||
@@ -126,11 +126,12 @@ def profile_method(target: 'Target') -> Callable:
|
||||
self_obj, *args_tail = args
|
||||
|
||||
# execute the method and return the result
|
||||
assert isinstance(target, str), f'{target} instance is not str.'
|
||||
assert isinstance(target, str), f"{target} instance is not str."
|
||||
|
||||
out = getattr(self_obj, target)(*args_tail, **kwargs)
|
||||
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
|
||||
target, INPLACE_METHOD, NON_INPLACE_METHOD)
|
||||
target, INPLACE_METHOD, NON_INPLACE_METHOD
|
||||
)
|
||||
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
|
||||
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
|
||||
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
|
||||
@@ -161,7 +162,7 @@ def profile_module(module: torch.nn.Module) -> Callable:
|
||||
fwd_tmp = 0
|
||||
fwd_out = 0
|
||||
out = func(*args, **kwargs)
|
||||
if getattr(module, 'inplace', False):
|
||||
if getattr(module, "inplace", False):
|
||||
fwd_out = activation_size(out)
|
||||
profiler = meta_profiler_module.get(type(module))
|
||||
fwd_flop, _ = profiler(module, *args, **kwargs)
|
||||
|
Reference in New Issue
Block a user