Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 10:34:41 +00:00.
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
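The updated pre-commit configuration itself is not part of this diff. As a rough sketch only, a `.pre-commit-config.yaml` matching the commit message (black-style quote normalization for Python, clang-format skipping CUDA sources) could look like the following; the revs and exact hook set are assumptions, not taken from the repository:

repos:
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0                # illustrative pin, not the actual rev
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 23.9.1                # illustrative pin; black rewrites '...' to "..."
    hooks:
      - id: black
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6               # illustrative pin
    hooks:
      - id: clang-format
        exclude: \.(cu|cuh)$   # "ignore cuda for clang-format"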
@@ -10,7 +10,6 @@ from colossalai.tensor.comm_spec import CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
-
 from .constants import (
     BCAST_FUNC_OP,
     ELEMENTWISE_FUNC_OP,
     ELEMENTWISE_METHOD_OP,
     ELEMENTWISE_MODULE_OP,
@@ -18,13 +17,14 @@ from .constants import (
     RESHAPE_METHOD_OP,
 )
 
-__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
+__all__ = ["OperationDataType", "OperationData", "TrainCycleItem", "MemoryCost", "ShardingStrategy", "StrategiesVector"]
 
 
 class OperationDataType(Enum):
     """
     An operation can come from the argument list of an operator or the parameter list of a module.
     """
+
     INPUT = 0
     ARG = 1
     PARAM = 2
@@ -43,6 +43,7 @@ class OperationData:
         data (Any): the value for this data, usually it is a meta tensor.
         logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory.
     """
+
     name: str
     type: OperationDataType
     data: Any
@@ -69,13 +70,13 @@ class OperationData:
             self.logical_shape = _infer_logical_shape(self.data)
 
     def __repr__(self) -> str:
-        return f'OperationData(name={self.name}, type={self.type})'
+        return f"OperationData(name={self.name}, type={self.type})"
 
     def __eq__(self, other) -> bool:
         return other.name == self.name
 
     def __hash__(self) -> int:
-        return hash(f'{self.name}')
+        return hash(f"{self.name}")
 
 
 @dataclass
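For context, a minimal usage sketch of the two dataclasses above (the names and shapes are illustrative; `logical_shape` is filled in by `__post_init__` as shown in the hunk):

import torch

# OperationData compares and hashes by name only, so two records with the
# same name refer to the same operand even if their tensors differ.
weight = OperationData(
    name="weight",
    type=OperationDataType.PARAM,
    data=torch.empty(64, 32, device="meta"),  # meta tensor: shape without storage
)
alias = OperationData(
    name="weight",
    type=OperationDataType.ARG,
    data=torch.empty(1, device="meta"),
)
assert weight == alias              # __eq__ compares names
assert hash(weight) == hash(alias)  # __hash__ hashes the name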
@@ -88,6 +89,7 @@ class TrainCycleItem:
         fwd (float): the item for the forward pass
         bwd (float): the item for the backward pass
     """
+
     fwd: Any
     bwd: Any
     total: Any
@@ -104,6 +106,7 @@ class MemoryCost:
         temp (int): the memory cost incurred by the temporary tensors in bytes.
         buffer (int): the memory cost incurred by the module buffer in bytes.
     """
+
     activation: int = 0
     parameter: int = 0
     temp: int = 0
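A hedged sketch of how these two dataclasses compose, tracking memory for the forward and backward pass of one node (all byte counts are made up):

fwd_mem = MemoryCost(activation=4 * 1024**2, parameter=1 * 1024**2)
bwd_mem = MemoryCost(activation=2 * 1024**2, parameter=1 * 1024**2, temp=512 * 1024)
total_mem = MemoryCost(
    activation=fwd_mem.activation + bwd_mem.activation,
    parameter=fwd_mem.parameter,  # parameter memory is shared, not doubled
)
memory_cost = TrainCycleItem(fwd=fwd_mem, bwd=bwd_mem, total=total_mem)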
@@ -120,6 +123,7 @@ class CommType(Enum):
         HOOK: the communication action is used to do the grad all reduce.
         IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm
     """
+
     BEFORE = 0
     AFTER = 1
     HOOK = 2
@@ -137,6 +141,7 @@ class CommAction:
         arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime,
                    because the args of node may be changed by graph transform passes.
     """
+
     comm_spec: CommSpec = None
     comm_type: CommType = None
     arg_index: int = -1
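A small sketch of how a CommAction might be filled in (illustrative only; a real comm_spec would be a CommSpec describing the collective, left at its None default here):

# Gradient all-reduce registered as a hook: no positional argument involved.
grad_sync = CommAction(comm_type=CommType.HOOK)

# A communication that must run before the op, applied to its first argument;
# arg_index is positional because node names can change under graph passes.
pre_comm = CommAction(comm_type=CommType.BEFORE, arg_index=0)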
@@ -156,6 +161,7 @@ class ShardingStrategy:
         memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
         input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
     """
+
     name: str
     sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
     compute_cost: TrainCycleItem = None
@@ -200,7 +206,6 @@ class ShardingStrategy:
             raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
 
     def clone(self):
-
         def _deepcopy_dict_vals(data: Dict):
             return {k: deepcopy(v) for k, v in data.items()}
 
@@ -209,31 +214,34 @@ class ShardingStrategy:
         # Consider the examples below:
         # If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False.
         # In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items.
-        communication_actions = _deepcopy_dict_vals(
-            self.communication_actions) if self.communication_actions is not None else None
+        communication_actions = (
+            _deepcopy_dict_vals(self.communication_actions) if self.communication_actions is not None else None
+        )
         # same reason as communication_actions
         resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
         compute_cost = deepcopy(self.compute_cost)
         communication_cost = deepcopy(self.communication_cost)
         memory_cost = deepcopy(self.memory_cost)
 
-        return ShardingStrategy(name=self.name,
-                                sharding_specs=sharding_specs,
-                                compute_cost=compute_cost,
-                                communication_cost=communication_cost,
-                                memory_cost=memory_cost,
-                                communication_actions=communication_actions,
-                                resharding_costs=resharding_costs)
+        return ShardingStrategy(
+            name=self.name,
+            sharding_specs=sharding_specs,
+            compute_cost=compute_cost,
+            communication_cost=communication_cost,
+            memory_cost=memory_cost,
+            communication_actions=communication_actions,
+            resharding_costs=resharding_costs,
+        )
 
 
 class StrategiesVector(list):
-    '''
+    """
     Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
     strategies of the node.
 
     Argument:
         node (Node): node for which the list of sharding strategies are generated.
-    '''
+    """
 
     def __init__(self, node: Node):
         super().__init__()
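The comment in clone() is worth spelling out: an empty dict is falsy but still a valid value, so the guard must be `is not None` rather than truthiness. A two-line illustration:

actions = {}
assert actions is not None  # an empty dict is a real value...
assert not actions          # ...even though it is falsy
# Guarding with `if self.communication_actions:` would map {} to None in the
# clone, and iterating None.items() later would raise AttributeError.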
@@ -245,7 +253,7 @@ class StrategiesVector(list):
 
     def check_merge(self):
         merge_label = False
-        if self.node.op == 'call_module':
+        if self.node.op == "call_module":
             target = self.node.target
             root_module = self.node.graph.owning_module
             submod = root_module.get_submodule(target)
@@ -255,7 +263,7 @@ class StrategiesVector(list):
             if submod_type in ELEMENTWISE_MODULE_OP:
                 merge_label = True
 
-        if self.node.op == 'call_function':
+        if self.node.op == "call_function":
             # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
             if self.node.target in ELEMENTWISE_FUNC_OP:
                 merge_label = True
@@ -267,7 +275,7 @@ class StrategiesVector(list):
             if self.node.target in RESHAPE_FUNC_OP:
                 merge_label = True
 
-        if self.node.op == 'call_method':
+        if self.node.op == "call_method":
             # we could merge reshape op, because their computation costs are negligible.
             method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
             if method in RESHAPE_METHOD_OP:
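check_merge() keys off node.op, which torch.fx assigns during tracing. A standalone sketch (independent of ColossalAI) showing where those values come from:

import torch
import torch.fx

class Tiny(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)  # traced as a call_function node (element-wise)
        return y.view(-1)  # traced as a call_method node (reshape)

gm = torch.fx.symbolic_trace(Tiny())
for node in gm.graph.nodes:
    print(node.op, node.target)
# prints placeholder/x, call_function/torch.relu, call_method/"view", output/output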