[booster] update prepare dataloader method for plugin (#3706)

* [booster] add prepare dataloader method for plugin

* [booster] update examples and docstrings
Hongxin Liu
2023-05-08 15:44:03 +08:00
committed by GitHub
parent f83ea813f5
commit 3bf09efe74
9 changed files with 41 additions and 40 deletions


@@ -20,21 +20,19 @@ class DPPluginBase(Plugin):
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
 
-    def prepare_train_dataloader(self,
-                                 dataset,
-                                 batch_size,
-                                 shuffle=False,
-                                 seed=1024,
-                                 drop_last=False,
-                                 pin_memory=False,
-                                 num_workers=0,
-                                 **kwargs):
+    def prepare_dataloader(self,
+                           dataset,
+                           batch_size,
+                           shuffle=False,
+                           seed=1024,
+                           drop_last=False,
+                           pin_memory=False,
+                           num_workers=0,
+                           **kwargs):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
         `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
 
-        Note:
-            1. Evaluation datasets should not be passed to this function.
 
         Args:
             dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
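
For readers skimming the diff: the renamed method is expected to behave like its predecessor, pairing a DistributedSampler (built from the plugin's rank and world size) with a standard DataLoader. A minimal sketch under that assumption; the real implementation may also seed dataloader workers, which is omitted here:

# Sketch only: assumes prepare_dataloader simply pairs DistributedSampler with DataLoader.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def prepare_dataloader(self, dataset, batch_size, shuffle=False, seed=1024,
                       drop_last=False, pin_memory=False, num_workers=0, **kwargs):
    # Each rank sees a disjoint shard of the dataset; the seed keeps shuffling consistent.
    sampler = DistributedSampler(dataset,
                                 num_replicas=self.world_size,
                                 rank=self.rank,
                                 shuffle=shuffle,
                                 seed=seed)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      sampler=sampler,
                      drop_last=drop_last,
                      pin_memory=pin_memory,
                      num_workers=num_workers,
                      **kwargs)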


@@ -156,7 +156,7 @@ class GeminiPlugin(DPPluginBase):
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = GeminiPlugin()
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
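
The rename drops the train-only wording, so the same helper can presumably be used for every dataset split. A short illustration with hypothetical variable names:

# Assumption: with the rename, evaluation data can go through the same helper.
# train_dataset and val_dataset are hypothetical placeholders.
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = plugin.prepare_dataloader(val_dataset, batch_size=8, shuffle=False)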


@@ -95,7 +95,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = LowLevelZeroPlugin()
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)


@@ -4,7 +4,7 @@ from typing import Callable, List, Tuple, Union
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from colossalai.checkpoint_io import CheckpointIO
 from colossalai.interface import OptimizerWrapper
@@ -59,3 +59,18 @@ class Plugin(ABC):
         Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
         """
         pass
+
+    @abstractmethod
+    def prepare_dataloader(self,
+                           dataset: Dataset,
+                           batch_size: int,
+                           shuffle: bool = False,
+                           seed: int = 1024,
+                           drop_last: bool = False,
+                           pin_memory: bool = False,
+                           num_workers: int = 0,
+                           **kwargs):
+        """Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader`
+        """
+        pass
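
Since the new hook is declared with @abstractmethod, every concrete plugin has to provide it. A hypothetical minimal implementation (not part of this commit; the other abstract methods of Plugin are omitted, so this sketch is not directly instantiable) might look like:

# Hypothetical example: a plugin that runs without a distributed sampler.
# The remaining abstract methods of Plugin are omitted for brevity.
from torch.utils.data import DataLoader, Dataset

class SingleProcessPlugin(Plugin):

    def prepare_dataloader(self,
                           dataset: Dataset,
                           batch_size: int,
                           shuffle: bool = False,
                           seed: int = 1024,
                           drop_last: bool = False,
                           pin_memory: bool = False,
                           num_workers: int = 0,
                           **kwargs):
        # No sampler is attached; the seed is unused in this single-process case.
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          drop_last=drop_last,
                          pin_memory=pin_memory,
                          num_workers=num_workers,
                          **kwargs)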


@@ -72,7 +72,7 @@ class TorchDDPPlugin(DPPluginBase):
         >>> model, train_dataset, optimizer, criterion = ...
         >>> plugin = TorchDDPPlugin()
-        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
         >>> booster = Booster(plugin=plugin)
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
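
For context, a complete training script built on the renamed method would look roughly like the following. The helper build_model_and_data, the hyperparameters, and the launch and backward calls are assumptions based on ColossalAI's usual Booster workflow, not part of this diff:

# End-to-end sketch of the renamed API; build_model_and_data and the hyperparameters
# are hypothetical placeholders.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})  # assumes a torchrun-style launch

model, train_dataset, criterion = build_model_and_data()  # hypothetical helper
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

plugin = TorchDDPPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True)

booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

model.train()
for inputs, labels in train_dataloader:
    loss = criterion(model(inputs), labels)
    booster.backward(loss, optimizer)  # assumed Booster helper for the backward pass
    optimizer.step()
    optimizer.zero_grad()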