mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Callable, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
@@ -9,11 +9,10 @@ from torch.utils.data import DataLoader, Dataset
|
||||
from colossalai.checkpoint_io import CheckpointIO
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
|
||||
__all__ = ['Plugin']
|
||||
__all__ = ["Plugin"]
|
||||
|
||||
|
||||
class Plugin(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def supported_devices(self) -> List[str]:
|
||||
pass
|
||||
@@ -51,33 +50,31 @@ class Plugin(ABC):
|
||||
"""
|
||||
Whether the plugin controls the checkpoint io
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
"""
|
||||
Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
"""
|
||||
Context manager to disable gradient synchronization.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prepare_dataloader(self,
|
||||
dataset: Dataset,
|
||||
batch_size: int,
|
||||
shuffle: bool = False,
|
||||
seed: int = 1024,
|
||||
drop_last: bool = False,
|
||||
pin_memory: bool = False,
|
||||
num_workers: int = 0,
|
||||
**kwargs):
|
||||
def prepare_dataloader(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
batch_size: int,
|
||||
shuffle: bool = False,
|
||||
seed: int = 1024,
|
||||
drop_last: bool = False,
|
||||
pin_memory: bool = False,
|
||||
num_workers: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
`torch.utils.data.DataLoader`
|
||||
"""
|
||||
pass
|
||||
|
Reference in New Issue
Block a user