mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-05 00:56:17 +00:00
Refactored docstring to google style
This commit is contained in:
@@ -19,18 +19,15 @@ T_co = TypeVar('T_co', covariant=True)
|
||||
|
||||
@DATA_SAMPLERS.register_module
|
||||
class DataParallelSampler(Sampler):
|
||||
"""A data sampler for distributed data parallelism
|
||||
"""A data sampler for distributed data parallelism.
|
||||
|
||||
:param dataset: A Dataset instance
|
||||
:type dataset: torch.utils.data.Dataset
|
||||
:param shuffle: Whether to shuffle data, defaults to False
|
||||
:type shuffle: bool, optional
|
||||
:param seed: The random seed, defaults to 0
|
||||
:type seed: int, optional
|
||||
:param drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch
|
||||
size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller,
|
||||
defaults to False
|
||||
:type drop_last: bool, optional
|
||||
Args:
|
||||
dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling.
|
||||
shuffle (bool, optional): Whether to shuffle data, defaults to False.
|
||||
seed (int, optional): The random seed used for sampling, defaults to 0.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -104,8 +101,8 @@ class DataParallelSampler(Sampler):
|
||||
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
||||
sampler will yield the same ordering.
|
||||
|
||||
:param epoch: Epoch number.
|
||||
:type epoch: int
|
||||
Args:
|
||||
epoch (int): Epoch number.
|
||||
"""
|
||||
self.epoch = epoch
|
||||
|
||||
@@ -118,29 +115,27 @@ def get_dataloader(dataset,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
**kwargs):
|
||||
"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
|
||||
r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
|
||||
|
||||
.. note:: When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data
|
||||
on the 1st stage and label on the last stage
|
||||
Note:
|
||||
When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data
|
||||
on the 1st stage and label on the last stage.
|
||||
|
||||
:param dataset: A :class:`torch.utils.data.Dataset` object
|
||||
:param shuffle: Whether to shuffle the dataset
|
||||
:param seed: Random worker seed, defaults to 1024
|
||||
:param add_sampler: Add DistributedDataParallelSampelr to the dataset
|
||||
:param drop_last: Drop the last incomplete batch of data
|
||||
:param pin_memory: Whether to pin memory address in CPU memory
|
||||
:param num_workers: Number of worker threads for this dataloader
|
||||
Args:
|
||||
dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
||||
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
||||
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
||||
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
||||
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||
|
||||
:type dataset: :class:`torch.utils.data.Dataset`
|
||||
:type shuffle: bool, optional. Default is False
|
||||
:type seed: int, optional. Default is 1024
|
||||
:type add_sampler: bool, optional. Default is True
|
||||
:type drop_last: bool, optional. Default is False
|
||||
:type pin_memory: bool, optional. Default is False
|
||||
:type num_workers: int, optional. Default is 0
|
||||
|
||||
:return: A object of :class:`torch.utils.data.DataLoader`
|
||||
:rtype: :class:`torch.utils.data.DataLoader`
|
||||
Returns:
|
||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
|
||||
|
Reference in New Issue
Block a user