[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -8,8 +8,11 @@ import torch
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
"transpose_partition_dim",
"update_partition_dim",
"enumerate_all_possible_1d_sharding",
"enumerate_all_possible_2d_sharding",
"generate_sharding_size",
]
@@ -22,8 +25,7 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
dim1 (int): the tensor dimension to switch
dim2 (int): the tensor dimension to switch
"""
assert len(sharding_spec.entire_shape) >= 2, \
'The entire_shape of the sharding spec must have at least 2 dimensions'
assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions"
dim_partition_dict = sharding_spec.dim_partition_dict
# transpose the dim partition
@@ -45,10 +47,9 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
return sharding_spec
def update_partition_dim(sharding_spec: ShardingSpec,
dim_mapping: Dict[int, int],
physical_shape: torch.Size,
inplace: bool = False):
def update_partition_dim(
sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False
):
"""
This method is used to update the partition dim dict from the logical one to the physical one.
@@ -78,9 +79,9 @@ def update_partition_dim(sharding_spec: ShardingSpec,
new_dim_partition_dict[tensor_dim] = mesh_dims
# update sharding spec
current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=new_dim_partition_dict)
current_sharding_spec.__init__(
device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict
)
return current_sharding_spec