Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 18:19:58 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
__init__.py
@@ -2,4 +2,4 @@ from .base_shard_strategy import BaseShardStrategy
 from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
 from .tensor_shard_strategy import TensorShardStrategy
 
-__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
+__all__ = ["BaseShardStrategy", "TensorShardStrategy", "BucketTensorShardStrategy"]
base_shard_strategy.py
@@ -7,10 +7,8 @@ from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
 
 
 class BaseShardStrategy(ABC):
-
     def __init__(self) -> None:
-        """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
-        """
+        """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs."""
         super().__init__()
 
     @abstractmethod
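BaseShardStrategy is the abstract base that TensorShardStrategy and BucketTensorShardStrategy (both reformatted below) implement. As a reading aid only, a minimal sketch of the interface shape; the shard/gather signatures are assumptions inferred from the subclass methods visible in this diff, not copied from the file:

from abc import ABC, abstractmethod
from typing import List, Optional

import torch.distributed as dist


class ShardedTensor:  # hypothetical stand-in for colossalai's ShardedTensor
    is_sharded: bool


class BaseShardStrategy(ABC):
    def __init__(self) -> None:
        """Abstract strategy for sharding tensors across multiple GPUs."""
        super().__init__()

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        ...

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        ...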
bucket_tensor_shard_strategy.py
@@ -18,7 +18,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
     """
 
     def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
-
         tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
         if len(tensor_list) == 0:
             return
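The hunk above shows only gather's filtering prologue; the communication itself sits outside the diff context. A rough sketch of the likely flow, with the buffer construction and the payload attribute assumed rather than taken from the file:

import torch
import torch.distributed as dist


def gather_sketch(tensor_list, process_group=None):
    # Filter, exactly as in the hunk above.
    tensor_list = [t for t in tensor_list if t.is_sharded]
    if len(tensor_list) == 0:
        return
    # Assumed: each ShardedTensor keeps its local shard in t.payload.
    tensor_numels = [t.payload.numel() for t in tensor_list]
    flat = torch.cat([t.payload.flatten() for t in tensor_list])
    world_size = dist.get_world_size(process_group)
    buffer_list = [torch.empty_like(flat) for _ in range(world_size)]
    # One collective for the whole bucket rather than one per tensor.
    dist.all_gather(buffer_list, flat, group=process_group)
    # Per-tensor reconstruction from buffer_list proceeds as in the next hunk.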
@@ -40,8 +39,8 @@ class BucketTensorShardStrategy(TensorShardStrategy):
         buffer_list = [buffer.to(target_device) for buffer in buffer_list]
         offset = 0
         for i, t in enumerate(tensor_list):
-            gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
-            gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
+            gathered_payload = [buffer[offset : offset + tensor_numels[i]] for buffer in buffer_list]
+            gathered_payload = torch.cat(gathered_payload)[: t.origin_numel].view(t.origin_shape)
             t.payload_reset(gathered_payload)
             t.is_sharded = False
             offset += tensor_numels[i]
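To make the slicing concrete: each rank's buffer holds that rank's shard of every tensor back to back, so a tensor is rebuilt by taking its slice from every buffer, concatenating across ranks, trimming padding, and reshaping. A tiny self-contained example with synthetic data (not ColossalAI code):

import torch

world_size = 2
origin_shape = (3, 3)  # 9 elements, padded to 10 so they split evenly over 2 ranks
origin_numel = 9
shard_numel = 5        # plays the role of tensor_numels[i]

# Pretend all_gather already produced one flat buffer per rank.
flat = torch.arange(10, dtype=torch.float32)
buffer_list = [flat[:shard_numel], flat[shard_numel:]]

offset = 0  # this tensor's start position inside each buffer
gathered_payload = [buf[offset : offset + shard_numel] for buf in buffer_list]
payload = torch.cat(gathered_payload)[:origin_numel].view(origin_shape)
assert payload.shape == origin_shape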
tensor_shard_strategy.py
@@ -24,7 +24,7 @@ class TensorShardStrategy(BaseShardStrategy):
             self._gather_tensor(t, process_group)
 
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
-        """ Shard tensor among processes.
+        """Shard tensor among processes.
 
         Args:
             t (ShardedTensor): a tensor to be sharded.
@@ -33,9 +33,11 @@ class TensorShardStrategy(BaseShardStrategy):
         """
         if t.is_sharded:
             return
-        if t.payload.device.type == 'cuda':
-            assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
-                f" but current cuda device is {get_current_device()}"
+        if t.payload.device.type == "cuda":
+            assert t.payload.device == get_current_device(), (
+                f"shard tensor on cuda device index {t.payload.device.index},"
+                f" but current cuda device is {get_current_device()}"
+            )
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.payload_reset(sharded_payload)
         t.is_sharded = True
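get_shard itself is not part of this diff; a plausible pad-and-chunk sketch that is consistent with the gather logic above (the helper name matches the caller, but the body and the second return value, discarded as _ by the caller, are assumptions):

import torch
import torch.nn.functional as F


def get_shard(tensor: torch.Tensor, rank: int, world_size: int):
    # Hypothetical sketch, not the colossalai implementation.
    flat = tensor.flatten()
    shard_numel = (flat.numel() + world_size - 1) // world_size
    padded = F.pad(flat, (0, shard_numel * world_size - flat.numel()))
    shard = padded[rank * shard_numel : (rank + 1) * shard_numel].clone()
    return shard, flat.numel()  # second value assumed; the caller discards it


shard, _ = get_shard(torch.arange(9.0), rank=1, world_size=2)
assert shard.tolist() == [5.0, 6.0, 7.0, 8.0, 0.0]  # trailing 0.0 is padding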