ColossalAI/colossalai/legacy/zero/shard_utils/base_shard_strategy.py
Hongxin Liu 079bf3cb26
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
2023-09-19 14:20:26 +08:00

21 lines
635 B
Python

from abc import ABC, abstractmethod
from typing import List, Optional
import torch.distributed as dist
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
class BaseShardStrategy(ABC):
    """Abstract base class for shard strategies.

    A shard strategy shards tensors (``shard``) and gathers them back
    (``gather``) over a ``torch.distributed`` process group, so that the
    tensors can be spread over multiple GPUs. Concrete strategies must
    implement both abstract methods; this class cannot be instantiated
    directly.
    """

    def __init__(self) -> None:
        """Initialise the strategy. The base class itself holds no state."""
        super().__init__()

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Shard every tensor in ``tensor_list`` over ``process_group``.

        Args:
            tensor_list: The tensors to shard.
            process_group: The process group to shard over. ``None`` is
                accepted; its interpretation (presumably the default group)
                is left to the concrete strategy.

        NOTE(review): whether implementations modify the tensors in place
        or return something is not defined here; confirm in subclasses.
        """

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Gather every tensor in ``tensor_list`` back from ``process_group``.

        This is the counterpart of :meth:`shard`.

        Args:
            tensor_list: The sharded tensors to gather.
            process_group: The process group to gather over. ``None`` is
                accepted; its interpretation (presumably the default group)
                is left to the concrete strategy.

        NOTE(review): whether implementations modify the tensors in place
        or return something is not defined here; confirm in subclasses.
        """