[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -2,20 +2,13 @@ from typing import Callable, List
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['ShardMetaInfo']
__all__ = ["ShardMetaInfo"]
class ShardMetaInfo:
@@ -76,10 +69,12 @@ class ShardMetaInfo:
"""
if isinstance(sharding_spec, ShardingSpec):
op_data = OperationData(name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
op_data = OperationData(
name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape,
)
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
@@ -97,8 +92,9 @@ class ShardMetaInfo:
"""
Compute meta info based on sharding strategy and the given target function.
"""
assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
f"Meta info for {self._target} is not registered."
assert meta_register.has(self._target.__class__) or meta_register.has(
self._target
), f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__):
# module
meta_func = meta_register.get(self._target.__class__)
@@ -117,11 +113,11 @@ class ShardMetaInfo:
# construct kwargs
if self.target in INPLACE_MODULE:
kwargs = {'inplace': self.target.inplace}
kwargs = {"inplace": self.target.inplace}
elif self.target in INPLACE_OPS:
kwargs = {'inplace': True}
kwargs = {"inplace": True}
else:
kwargs = {'inplace': False}
kwargs = {"inplace": False}
# compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)