[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -14,8 +14,11 @@ from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
'comm_actions_for_oprands'
"BroadcastType",
"is_broadcastable",
"get_broadcast_shape",
"recover_sharding_spec_for_broadcast_shape",
"comm_actions_for_oprands",
]
@@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
"""
Compute the broadcast shape given two shapes.
"""
assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable"
shape1_reverse = shape1[::-1]
shape2_reverse = shape2[::-1]
min_common_dim = min(len(shape1), len(shape2))
@@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
assert logical_num_dims >= physical_num_dims, \
'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'
assert (
logical_num_dims >= physical_num_dims
), "The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!"
# track the dim and its broadcasting type
logical_dim_broadcast_info = {}
@@ -85,8 +89,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
return logical_dim_broadcast_info
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
physical_shape: torch.Size) -> ShardingSpec:
def recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size
) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
@@ -124,15 +129,18 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
physical_dim_partition[physical_dim] = mesh_dim
physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=physical_dim_partition)
physical_sharding_spec = ShardingSpec(
device_mesh=logical_sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=physical_dim_partition,
)
return physical_sharding_spec, removed_dims
def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
sharding_spec: ShardingSpec) -> CommAction:
def comm_actions_for_oprands(
node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec
) -> CommAction:
"""
This method is used to generate communication actions for oprands which lose information
during convert logical shape to physical shape.
@@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
if len(removed_dims) == 1:
# if list length is 1, extract element from list to avoid using flatten device mesh
removed_dims = removed_dims[0]
comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
sharding_spec=sharding_spec,
logical_process_axis=removed_dims)
comm_spec = CommSpec(
comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
sharding_spec=sharding_spec,
logical_process_axis=removed_dims,
)
if op_data.type == OperationDataType.PARAM:
comm_type = CommType.HOOK
else:
@@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
for index, arg in enumerate(node.args):
if op_data.name == str(arg):
arg_index = index
assert arg_index >= 0, f'op_data should be an argument of node.'
assert arg_index >= 0, f"op_data should be an argument of node."
comm_action = CommAction(
comm_spec=comm_spec,
comm_type=comm_type,