[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
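
For context on the last commit-message bullet, clang-format exclusions for pre-commit live in .pre-commit-config.yaml. Below is a minimal sketch of such an exclusion, assuming the upstream mirrors-clang-format hook; the revision and regex are illustrative, not values taken from this commit:

repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v13.0.1  # illustrative pin, not the revision used in this commit
    hooks:
      - id: clang-format
        exclude: \.cu$  # skip CUDA sources ("ignore cuda for clang-format")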


@@ -1,12 +1,12 @@
 import functools
-from typing import Any, Callable, Dict, List, Tuple, Type, Union
+from typing import Any, Callable, Tuple, Type, Union
 
 import torch
 
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
 
-__all__ = ['ignore_sharding_exception', 'pytree_map']
+__all__ = ["ignore_sharding_exception", "pytree_map"]
 
 
 def ignore_sharding_exception(func):
@@ -48,29 +48,32 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
     tensor_num_dim = tensor.dim()
     num_devices_in_col = sharding_spec.device_mesh.shape[0]
     num_devices_in_row = sharding_spec.device_mesh.shape[1]
-    assert sharding_len == tensor_num_dim, \
-        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
+    assert (
+        sharding_len == tensor_num_dim
+    ), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})."
 
     # make sure the sharding is valid for each dim
     for i in range(tensor_num_dim):
         dim_size = tensor.shape[i]
         dim_spec = sharding_spec.sharding_sequence[i]
 
-        if str(dim_spec).startswith('S'):
-            devices_str = str(dim_spec).lstrip('S')
+        if str(dim_spec).startswith("S"):
+            devices_str = str(dim_spec).lstrip("S")
             num_devices = 1
 
-            if '0' in devices_str:
+            if "0" in devices_str:
                 num_devices *= num_devices_in_col
-            if '1' in devices_str:
+            if "1" in devices_str:
                 num_devices *= num_devices_in_row
 
-            assert dim_size >= num_devices and dim_size % num_devices == 0, \
-                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
+            assert (
+                dim_size >= num_devices and dim_size % num_devices == 0
+            ), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices."
 
     # make sure the entire shape matches the physical tensor shape
-    assert sharding_spec.entire_shape == tensor.shape, \
-        f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
+    assert (
+        sharding_spec.entire_shape == tensor.shape
+    ), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}"
 
 
 def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
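
The trailing context line above shows the signature of pytree_map. Below is a minimal sketch of the behavior that signature suggests: apply fn to leaves whose type matches process_types (or to every leaf when map_all=True), recursing through dicts, lists, and tuples. This is an illustrative sketch, not ColossalAI's actual implementation:

from typing import Any, Callable, Tuple, Type, Union

def pytree_map_sketch(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
    # Recurse into the common Python containers.
    if isinstance(obj, dict):
        return {k: pytree_map_sketch(v, fn, process_types, map_all) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(pytree_map_sketch(v, fn, process_types, map_all) for v in obj)
    # Leaf: transform it when its type matches, or unconditionally with map_all.
    if map_all or isinstance(obj, process_types):
        return fn(obj)
    return obj

# Example: double every int in a nested structure.
nested = {"a": [1, 2], "b": (3, {"c": 4})}
print(pytree_map_sketch(nested, lambda x: x * 2, process_types=int))
# -> {'a': [2, 4], 'b': (6, {'c': 8})}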