Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 03:52:01 +00:00.
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -2,20 +2,13 @@ from typing import Callable, List
 
 import torch
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
 from .registry import meta_register
 
-__all__ = ['ShardMetaInfo']
+__all__ = ["ShardMetaInfo"]
 
 
 class ShardMetaInfo:
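The quote flip in the __all__ line above recurs throughout the commit: it is black's string normalization, applied when pre-commit ran over all files. As a quick illustration (assuming black is installed; the line_length=120 setting is an assumption about the project's configuration), black's Python API reproduces it:

import black

src = "__all__ = ['ShardMetaInfo']\n"
# Black normalizes single-quoted strings to double quotes by default.
print(black.format_str(src, mode=black.Mode(line_length=120)), end="")
# Output: __all__ = ["ShardMetaInfo"]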
@@ -76,10 +69,12 @@ class ShardMetaInfo:
         """
 
         if isinstance(sharding_spec, ShardingSpec):
-            op_data = OperationData(name=operation_data.name,
-                                    data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
-                                    type=operation_data.type,
-                                    logical_shape=operation_data.logical_shape)
+            op_data = OperationData(
+                name=operation_data.name,
+                data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
+                type=operation_data.type,
+                logical_shape=operation_data.logical_shape,
+            )
         elif isinstance(sharding_spec, (list, tuple)):
             data = operation_data.data
             assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
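The reformatted OperationData call keeps its placeholder on the meta device. Meta tensors record shape and dtype without allocating storage, which is what makes this kind of shape-level cost analysis cheap. A minimal standalone sketch (the shape here is made up):

import torch

# A meta tensor carries shape/dtype metadata but allocates no storage.
t = torch.zeros((4096, 4096), device="meta")
print(t.is_meta)                     # True
print(t.shape, t.dtype)              # torch.Size([4096, 4096]) torch.float32
print(t.numel() * t.element_size())  # bytes a real tensor of this shape would occupy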
@@ -97,8 +92,9 @@ class ShardMetaInfo:
         """
         Compute meta info based on sharding strategy and the given target function.
         """
-        assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
-            f"Meta info for {self._target} is not registered."
+        assert meta_register.has(self._target.__class__) or meta_register.has(
+            self._target
+        ), f"Meta info for {self._target} is not registered."
         if meta_register.has(self._target.__class__):
             # module
             meta_func = meta_register.get(self._target.__class__)
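The assert above relies on meta_register exposing has() and get() for either a module class or a plain callable. That is a decorator-registry pattern; a toy sketch assuming only the interface visible in this diff (the real implementation in .registry may differ):

import torch
from typing import Callable, Dict, Hashable

class MetaRegistry:
    # Toy registry mirroring the has/get/register interface used above.
    def __init__(self) -> None:
        self._funcs: Dict[Hashable, Callable] = {}

    def register(self, target: Hashable) -> Callable:
        def decorator(func: Callable) -> Callable:
            self._funcs[target] = func
            return func
        return decorator

    def has(self, target: Hashable) -> bool:
        return target in self._funcs

    def get(self, target: Hashable) -> Callable:
        return self._funcs[target]

meta_register = MetaRegistry()

@meta_register.register(torch.nn.Linear)
def linear_meta(*args, **kwargs):
    # A real meta function would return compute cost, memory cost,
    # and the forward input/buffer/output tensors.
    ...

assert meta_register.has(torch.nn.Linear)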
@@ -117,11 +113,11 @@ class ShardMetaInfo:
 
         # construct kwargs
         if self.target in INPLACE_MODULE:
-            kwargs = {'inplace': self.target.inplace}
+            kwargs = {"inplace": self.target.inplace}
         elif self.target in INPLACE_OPS:
-            kwargs = {'inplace': True}
+            kwargs = {"inplace": True}
         else:
-            kwargs = {'inplace': False}
+            kwargs = {"inplace": False}
 
         # compute metainfo with meta_func
         self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
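The quote-only edits in this last hunk leave the dispatch unchanged: modules listed in INPLACE_MODULE contribute their own inplace flag, known in-place ops force inplace=True, and everything else defaults to False. A hypothetical standalone rendering of that branch (the stand-in constants and the type(target) check are illustrative, not the contents of .constants):

import torch

# Illustrative stand-ins; the real sets live in colossalai's .constants module.
INPLACE_MODULE = {torch.nn.ReLU}
INPLACE_OPS = {torch.nn.functional.relu}

def build_inplace_kwargs(target) -> dict:
    if type(target) in INPLACE_MODULE:
        # Module instances expose their own inplace flag.
        return {"inplace": target.inplace}
    if target in INPLACE_OPS:
        return {"inplace": True}
    return {"inplace": False}

print(build_inplace_kwargs(torch.nn.ReLU(inplace=True)))  # {'inplace': True}
print(build_inplace_kwargs(torch.matmul))                 # {'inplace': False}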