ColossalAI/colossalai/zero/legacy/shard_utils/base_shard_strategy.py
ver217 26b7aac0be
[zero] reorganize zero/gemini folder structure (#3424)
* [zero] refactor low-level zero folder structure
* [zero] fix legacy zero import path
* [zero] fix legacy zero import path
* [zero] remove useless import
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] fix test import path
* [zero] fix test
* [zero] fix circular import
* [zero] update import
2023-04-04 13:48:16 +08:00

23 lines
645 B
Python

from abc import ABC, abstractmethod
from typing import List, Optional

import torch.distributed as dist

from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor


class BaseShardStrategy(ABC):

    def __init__(self) -> None:
        """Abstract shard strategy. Used to shard tensors across multiple GPUs."""
        super().__init__()

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        pass

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        pass
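
For context, here is a minimal sketch of how a concrete strategy can be built on this interface. The NoOpShardStrategy name and its do-nothing behaviour are hypothetical illustrations only, not part of ColossalAI; the real strategies shipped alongside this file shard each ShardedTensor's payload across the ranks of the given process group.

from typing import List, Optional

import torch.distributed as dist

from colossalai.zero.legacy.shard_utils.base_shard_strategy import BaseShardStrategy
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor


class NoOpShardStrategy(BaseShardStrategy):
    """Hypothetical strategy that keeps every tensor whole on each rank."""

    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        # No-op: tensors stay fully materialized on every rank.
        pass

    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        # No-op: nothing was sharded, so there is nothing to gather.
        pass

Each abstract method receives the full list of ShardedTensor objects plus an optional process group, so a concrete strategy can issue whatever collective communication it needs in one place.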