Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
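The hunks below are purely mechanical output of the updated hooks, dominated by black-style string normalization (single quotes rewritten to double quotes) and trailing-comment spacing. The repository's hook configuration is not part of this diff, so the following is only a sketch of reproducing the quote normalization locally with black's Python API; the `line_length=120` setting is an assumption about the project's style, not taken from this page.

import black  # assumes black is installed (pip install black)

# Two lines from the hunks below, in their pre-commit form.
SRC = "mod_dir: str = ''\nsharding_spec: str = 'RR'\n"

# black.format_str applies the same normalization the hook applies on commit.
print(black.format_str(SRC, mode=black.Mode(line_length=120)))
# mod_dir: str = ""
# sharding_spec: str = "RR"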
@@ -1,9 +1,9 @@
 from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from torch.autograd.profiler_util import _format_memory, _format_time
-from torch.fx import Graph, GraphModule, Node
+from torch.autograd.profiler_util import _format_memory
+from torch.fx import Node
 
 from colossalai._analyzer.envs import MeshConfig
 
@@ -85,12 +85,12 @@ class MetaInfo:
     node: Node
 
     # directory
-    mod_dir: str = ''
+    mod_dir: str = ""
 
     # ctx[data_ptr] = Tensor
     # mark the storage for ctx.save_for_backward
-    global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})    # globally shared
-    curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})    # global_ctx till this node
+    global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})  # globally shared
+    curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})  # global_ctx till this node
 
     # should be updated after each graph manipulation
     # ============================== Update ====================================
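A detail in the hunk above that the reformatting leaves untouched: `global_ctx` and `curr_ctx` use `field(default_factory=lambda: {})` because dataclasses forbid mutable default values, and the factory gives every `MetaInfo` its own dict. A minimal standalone sketch of that pattern (the `Ctx` class here is illustrative, not from the repository):

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Ctx:
    # A bare `cache: Dict[str, int] = {}` raises ValueError when the class
    # is defined; default_factory builds a fresh dict per instance instead.
    cache: Dict[str, int] = field(default_factory=dict)


a, b = Ctx(), Ctx()
a.cache["x"] = 1
assert b.cache == {}  # each instance owns a separate dict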
@@ -100,7 +100,7 @@ class MetaInfo:
 
     inputs: Tuple[torch.Tensor] = ()
     outputs: Tuple[torch.Tensor] = ()
-    is_alias: Tuple[bool] = ()    # whether the output is an alias of input
+    is_alias: Tuple[bool] = ()  # whether the output is an alias of input
 
     # compute cost
     fwd_flop: Optional[int] = 0
@@ -112,29 +112,29 @@ class MetaInfo:
 
     # should keep the same whenever manipulated
     # ============================= Invariant ==================================
-    activation_checkpoint: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) support nested codegen
+    activation_checkpoint: Tuple[torch.Tensor] = ()  # (region_0, region_1, ...) support nested codegen
     to_offload: Optional[bool] = False
-    sharding_spec: str = 'RR'
+    sharding_spec: str = "RR"
 
     def __new__(cls, node: Node, **kwargs):
         orig_init = cls.__init__
 
         # if initialized, return the existing one
         # should disable the __init__ function
-        if node.meta.get('info', None) is not None:
+        if node.meta.get("info", None) is not None:
 
             def _dummy(self, *args, **kwargs):
-                if getattr(self, '_is_init', False):
+                if getattr(self, "_is_init", False):
                     self._is_init = True
                     orig_init(self, *args, **kwargs)
                 cls.__init__ = orig_init
 
             cls.__init__ = _dummy
-            return node.meta['info']
+            return node.meta["info"]
         return super().__new__(cls)
 
     def __post_init__(self):
-        self.node.meta['info'] = self
+        self.node.meta["info"] = self
 
     @property
     def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
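The `__new__` logic reformatted above is a per-node cache: when a `MetaInfo` already lives in `node.meta["info"]`, construction returns that existing object, temporarily installing a guarded `__init__` so the second initialization cannot clobber the cached state. A stripped-down sketch of the same memoization idea, independent of torch.fx (this `Node` stand-in and the `tag` field are hypothetical, and the init guard is simplified relative to the original's `__init__`-swapping):

class Node:
    # Hypothetical stand-in for torch.fx.Node's metadata dict.
    def __init__(self):
        self.meta = {}


class Info:
    def __new__(cls, node, **kwargs):
        if node.meta.get("info") is not None:
            return node.meta["info"]  # reuse the cached instance
        self = super().__new__(cls)
        node.meta["info"] = self
        return self

    def __init__(self, node, tag=None):
        # Guard: __init__ runs again on the cached instance, so bail out
        # once initialized instead of overwriting its state.
        if getattr(self, "_is_init", False):
            return
        self._is_init = True
        self.node = node
        self.tag = tag


n = Node()
assert Info(n, tag="a") is Info(n, tag="b")  # same object both times
assert Info(n).tag == "a"  # later constructor calls do not reset state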
@@ -188,24 +188,26 @@ class MetaInfo:
         return compute_size_in_bytes(self.inputs)
 
     def __repr__(self):
-        s = f'Node {self.node.name}'
+        s = f"Node {self.node.name}"
         if self.parameters:
-            s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
+            s += f"\n\thas parameter of size {_format_memory(self.param_size)}"
         if self.buffers:
-            s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
+            s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}"
         if self.output_size:
-            s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
+            s += f"\n\thas output activation of size {_format_memory(self.output_size)}"
         # if self.total_size:
         #     s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
         if self.temp_size:
-            s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
+            s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}"
         if self.backward_size:
-            s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
-        s += f'\n\tfwd_flop = {self.fwd_flop}'\
-            f'\n\tbwd_flop = {self.bwd_flop}'\
-            f'\n\tfwd_comm = {self.fwd_comm}'\
-            f'\n\tbwd_comm = {self.bwd_comm}'\
-            f'\n\tto_recompute = {self.to_recompute}'\
-            f'\n\tto_offload = {self.to_offload}'\
-            f'\n\tsharding_spec = {self.sharding_spec}'
+            s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}"
+        s += (
+            f"\n\tfwd_flop = {self.fwd_flop}"
+            f"\n\tbwd_flop = {self.bwd_flop}"
+            f"\n\tfwd_comm = {self.fwd_comm}"
+            f"\n\tbwd_comm = {self.bwd_comm}"
+            f"\n\tto_recompute = {self.to_recompute}"
+            f"\n\tto_offload = {self.to_offload}"
+            f"\n\tsharding_spec = {self.sharding_spec}"
+        )
         return s
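The final hunk also replaces the backslash-continued chain of f-strings with black's preferred form, adjacent literals inside parentheses; the parser joins adjacent string literals, so the two spellings build the identical string. A small self-contained check (the values are made up for illustration):

fwd_flop, bwd_flop = 2048, 4096  # illustrative values only

old_style = f"\n\tfwd_flop = {fwd_flop}"\
    f"\n\tbwd_flop = {bwd_flop}"
new_style = (
    f"\n\tfwd_flop = {fwd_flop}"
    f"\n\tbwd_flop = {bwd_flop}"
)
assert old_style == new_style  # implicit concatenation, same result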