Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
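Every code change in the diff below is a mechanical single-quote to double-quote rewrite of the kind black applies when run via pre-commit. As a quick illustration, a minimal sketch using black's Python API reproduces the `__all__` change from the first hunk (black is assumed to be the formatter behind the updated hooks; the hook configuration itself is not shown in this excerpt):

# Minimal sketch, assuming `pip install black`; black is presumed (not confirmed
# by this excerpt) to be the formatter that produced the quote rewrites below.
import black

src = "__all__ = ['profile_function', 'profile_module', 'profile_method']\n"
print(black.format_str(src, mode=black.Mode()))
# prints: __all__ = ["profile_function", "profile_module", "profile_method"]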
@@ -15,7 +15,7 @@ from .memory_utils import activation_size, parameter_size
 from .opcount import flop_mapping
 from .tensor import MetaTensor
 
-__all__ = ['profile_function', 'profile_module', 'profile_method']
+__all__ = ["profile_function", "profile_module", "profile_method"]
 
 # super-dainiu: this cache should be global, otherwise it cannot
 # track duplicated tensors between nodes
@@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
     # backward is executed.
     # Hopefully, this attempt will provide a better estimation of memory.
     class FlopTensor(MetaTensor):
-
         _node: Node = None
 
         def __repr__(self):
@@ -186,24 +185,24 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
         def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
             kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
-            node = subgraph.create_node('call_function', func, args_node, kwargs_node)
+            node = subgraph.create_node("call_function", func, args_node, kwargs_node)
 
             out = super().__torch_dispatch__(func, types, args, kwargs)
 
             flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
-            node.meta['phase'] = phase
+            node.meta["phase"] = phase
 
             # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
             # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
             # `Phase.FORWARD`
             if phase == Phase.FORWARD:
                 if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
-                    node.meta['phase'] = Phase.PLACEHOLDER
+                    node.meta["phase"] = Phase.PLACEHOLDER
 
             # TODO(yby): specify `saved_tensors` for backward memory estimation
-            node.meta['saved_tensor'] = []
+            node.meta["saved_tensor"] = []
             if phase == Phase.BACKWARD:
-                node.meta['saved_tensor'] = normalize_tuple(out)
+                node.meta["saved_tensor"] = normalize_tuple(out)
 
             def wrap(x):
                 if isinstance(x, MetaTensor):
@@ -219,11 +218,14 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
             x = FlopTensor(x)
             if is_autogradable(x):
                 x.requires_grad_(True)
-            x._node = subgraph.create_node('placeholder',
-                                           'placeholder', (subgraph._root,),
-                                           name=subgraph._graph_namespace.create_name('input', x._tensor))
-            x._node.meta['phase'] = Phase.PLACEHOLDER
-            x._node.meta['saved_tensor'] = []
+            x._node = subgraph.create_node(
+                "placeholder",
+                "placeholder",
+                (subgraph._root,),
+                name=subgraph._graph_namespace.create_name("input", x._tensor),
+            )
+            x._node.meta["phase"] = Phase.PLACEHOLDER
+            x._node.meta["saved_tensor"] = []
         return x
 
     # Basically, we need to detach the args and kwargs from the outer graph.
@@ -235,7 +237,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
         if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
             tensor = x._tensor.detach()
             tensor.data_ptr = x._tensor.data_ptr
-            x._node.meta['saved_tensor'] += [tensor]
+            x._node.meta["saved_tensor"] += [tensor]
             if not do_not_cache:
                 cache.add(x._tensor.data_ptr())
         return x
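The hunk above belongs to a pack hook that records which tensors autograd saves for backward, deduplicated by data_ptr() through the global cache. A minimal, self-contained sketch of the underlying mechanism using plain PyTorch saved-tensor hooks (saved, pack, and unpack are illustrative names, not ColossalAI's):

# Minimal sketch (not ColossalAI's code): observe what autograd saves for backward,
# the mechanism the pack hook above builds on.
import torch

saved = []

def pack(x):
    saved.append(x.detach())  # record a handle; dedup by data_ptr() would go here
    return x                  # whatever pack returns is handed back to unpack

def unpack(x):
    return x

a = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    loss = (a * a).sum()
loss.backward()
print(f"autograd saved {len(saved)} tensor(s) for backward")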
@@ -284,7 +286,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
 
 
 @compatibility(is_backward_compatible=True)
-def profile_function(target: 'Target', device: str = 'meta') -> Callable:
+def profile_function(target: "Target", device: str = "meta") -> Callable:
     """
     Wrap a `call_function` node or `torch.nn.functional` in order to
     record the memory cost and FLOPs of the execution.
@@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
     """
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-
         # find the grad for parameter in args and kwargs
         param_size = 0
 
@@ -316,18 +317,18 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
         # still run the profiling but discard some results regarding `target`
         global do_not_cache
 
-        inplace = kwargs.get('inplace', False)
+        inplace = kwargs.get("inplace", False)
         if target in OUTPUT_SAVED_OPS:
             do_not_cache = True
         if inplace:
             do_not_cache = True
-            kwargs['inplace'] = False
-        if device == 'meta':
+            kwargs["inplace"] = False
+        if device == "meta":
             out, meta = _profile_meta(func, *args, **kwargs)
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
         if inplace:
-            kwargs['inplace'] = True
+            kwargs["inplace"] = True
             meta.bwd_mem_tmp = 0
             meta.bwd_mem_out = 0
         do_not_cache = False
@@ -341,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
 
 
 @compatibility(is_backward_compatible=True)
-def profile_method(target: 'Target', device: str = 'meta') -> Callable:
+def profile_method(target: "Target", device: str = "meta") -> Callable:
     """
     Wrap a `call_method` node
     record the memory cost and FLOPs of the execution.
@@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         # execute the method and return the result
-        assert isinstance(target, str), f'{target} instance is not str.'
-        if device == 'meta':
+        assert isinstance(target, str), f"{target} instance is not str."
+        if device == "meta":
             out, meta = _profile_meta(target, *args, **kwargs)
         else:
             out, meta = _profile_concrete(target, *args, **kwargs)
@@ -360,7 +361,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
 
 
 @compatibility(is_backward_compatible=True)
-def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
+def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable:
     """
     Wrap a `call_module` node or `torch.nn` in order to
     record the memory cost and FLOPs of the execution.
@@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
     """
 
    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-
        # calculate parameter size
        param_size = parameter_size(module)
 
@@ -384,13 +384,13 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
         # still run the profiling but discard some results regarding `module`.
         global do_not_cache
 
-        inplace = getattr(module, 'inplace', False)
+        inplace = getattr(module, "inplace", False)
         if type(module) in OUTPUT_SAVED_MOD:
             do_not_cache = True
         if inplace:
             do_not_cache = True
             module.inplace = False
-        if device == 'meta':
+        if device == "meta":
             out, meta = _profile_meta(func, *args, **kwargs)
         else:
             out, meta = _profile_concrete(func, *args, **kwargs)
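For context on the file being reformatted: FlopTensor's __torch_dispatch__ (second and third hunks above) intercepts every ATen call made on wrapped tensors, looks the op up in flop_mapping, and records an FX node with phase and saved-tensor metadata. A minimal, self-contained sketch of that interception pattern, assuming nothing from ColossalAI (CountingTensor and op_counts are illustrative names, not the library's API):

# Minimal sketch (illustrative, not ColossalAI's implementation) of the
# __torch_dispatch__ pattern FlopTensor relies on: wrap a tensor so every
# ATen op it participates in is intercepted and can be counted.
import torch
from torch.utils._pytree import tree_map

op_counts = {}

class CountingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # wrapper subclass: holds no storage itself, only mirrors elem's metadata
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad,
        )

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # record the op (FlopTensor instead accumulates flop_mapping[func](...))
        op_counts[str(func)] = op_counts.get(str(func), 0) + 1

        def unwrap(x):
            return x.elem if isinstance(x, CountingTensor) else x

        def wrap(x):
            return CountingTensor(x) if isinstance(x, torch.Tensor) else x

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

x = CountingTensor(torch.randn(8, 8))
y = CountingTensor(torch.randn(8, 8))
(x @ y).relu()
print(op_counts)  # e.g. {'aten.mm.default': 1, 'aten.relu.default': 1}

ColossalAI's FlopTensor follows the same wrap/unwrap shape but, per the diff, adds flop_mapping[func](args, normalize_tuple(out)) to a per-phase counter and attaches phase and saved_tensor metadata to nodes of an FX subgraph.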