mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[zero] a shard strategy in granularity of tensor (#307)
This commit is contained in:
4
colossalai/zero/shard_utils/__init__.py
Normal file
4
colossalai/zero/shard_utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
|
||||
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
|
||||
|
||||
__all__ = ['BaseShardStrategy', 'TensorShardStrategy']
|
27
colossalai/zero/shard_utils/base_shard_strategy.py
Normal file
27
colossalai/zero/shard_utils/base_shard_strategy.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
import torch.distributed as dist
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class BaseShardStrategy(ABC):
|
||||
|
||||
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||
self.process_group = process_group
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def shard(self, tensor_list: List[ShardedTensor]):
|
||||
r"""
|
||||
sharded the memory of tensor on multiple processes.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gather(self, tensor_list: List[ShardedTensor]):
|
||||
r"""
|
||||
duplicate tensor payload on each processes.
|
||||
"""
|
||||
pass
|
18
colossalai/zero/shard_utils/tensor_shard_strategy.py
Normal file
18
colossalai/zero/shard_utils/tensor_shard_strategy.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
import torch.distributed as dist
|
||||
from typing import List, Optional
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
|
||||
class TensorShardStrategy(BaseShardStrategy):
|
||||
|
||||
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||
super().__init__(process_group)
|
||||
|
||||
def shard(self, tensor_list: List[ShardedTensor]):
|
||||
for t in tensor_list:
|
||||
t.shard()
|
||||
|
||||
def gather(self, tensor_list: List[ShardedTensor]):
|
||||
for t in tensor_list:
|
||||
t.gather()
|
Reference in New Issue
Block a user