[booster] added the plugin base and torch ddp plugin (#3180)

* [booster] added the plugin base and torch ddp plugin

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-03-21 17:39:30 +08:00
committed by GitHub
parent e5f668f280
commit e7f3bed2d3
8 changed files with 378 additions and 86 deletions

View File

@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple, Union
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.booster.interface import OptimizerWrapper
__all__ = ['Plugin']
class Plugin(ABC):
@property
@abstractmethod
def supported_devices(self) -> List[str]:
pass
@property
@abstractmethod
def supported_precisions(self) -> List[str]:
pass
@property
@abstractmethod
def control_precision(self) -> bool:
pass
@property
@abstractmethod
def control_device(self) -> bool:
pass
@property
@abstractmethod
def support_no_sync(self) -> bool:
pass
@abstractmethod
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
# implement this method
pass