[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -9,7 +9,7 @@ from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
__all__ = ['profile_function', 'profile_module', 'profile_method']
__all__ = ["profile_function", "profile_module", "profile_method"]
# this is for compatibility use
@@ -42,6 +42,7 @@ class GraphInfo:
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
@@ -50,8 +51,7 @@ class GraphInfo:
bwd_mem_out: int = 0
CALL_FUNCTION_MSG = \
"""
CALL_FUNCTION_MSG = """
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
@@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
macs = ...
return flops, macs
"""
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
CALL_MODULE_MSG = \
"""
CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}"
CALL_MODULE_MSG = """
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
@@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable:
def profile_function(target: "Target") -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
@@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), CALL_FUNCTION_MSG.format(target)
target.__name__
), CALL_FUNCTION_MSG.format(target)
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
if target not in INPLACE_OPS and not kwargs.get("inplace", False):
fwd_out = activation_size(out)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
@@ -112,7 +112,7 @@ def profile_function(target: 'Target') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_method(target: 'Target') -> Callable:
def profile_method(target: "Target") -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
@@ -126,11 +126,12 @@ def profile_method(target: 'Target') -> Callable:
self_obj, *args_tail = args
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
assert isinstance(target, str), f"{target} instance is not str."
out = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
target, INPLACE_METHOD, NON_INPLACE_METHOD
)
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
@@ -161,7 +162,7 @@ def profile_module(module: torch.nn.Module) -> Callable:
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if getattr(module, 'inplace', False):
if getattr(module, "inplace", False):
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)