ColossalAI/colossalai/zero/shard_utils/base_shard_strategy.py
ver217 a241f61b34
[zero] Update initialize for ZeRO (#458)
* polish code

* shard strategy receive pg in shard() / gather()

* update zero engine

* polish code
2022-03-18 16:18:31 +08:00

22 lines
637 B
Python

from abc import ABC, abstractmethod
from typing import List, Optional
import torch.distributed as dist
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
class BaseShardStrategy(ABC):
    """Abstract shard strategy, used to shard tensors across multiple GPUs.

    Concrete strategies subclass this and implement both :meth:`shard` and
    :meth:`gather`. Because both are declared with ``@abstractmethod``, this
    base class itself cannot be instantiated.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def shard(self,
              tensor_list: List[ShardedTensor],
              process_group: Optional[dist.ProcessGroup] = None):
        """Shard every tensor in ``tensor_list``.

        Args:
            tensor_list: The ``ShardedTensor`` objects to shard.
            process_group: Process group to shard over. How ``None`` is
                handled is up to the concrete strategy (presumably the
                default group -- NOTE(review): confirm in subclasses).
        """

    @abstractmethod
    def gather(self,
               tensor_list: List[ShardedTensor],
               process_group: Optional[dist.ProcessGroup] = None):
        """Gather every tensor in ``tensor_list``; counterpart of :meth:`shard`.

        Args:
            tensor_list: The ``ShardedTensor`` objects to gather.
            process_group: Process group to gather over. How ``None`` is
                handled is up to the concrete strategy (presumably the
                default group -- NOTE(review): confirm in subclasses).
        """