mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -10,25 +10,19 @@ from .plugin_base import Plugin
|
||||
|
||||
|
||||
class DPPluginBase(Plugin):
|
||||
"""This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation.
|
||||
"""
|
||||
"""This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
assert dist.is_initialized(
|
||||
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
|
||||
assert (
|
||||
dist.is_initialized()
|
||||
), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment"
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
|
||||
def prepare_dataloader(self,
|
||||
dataset,
|
||||
batch_size,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
**kwargs):
|
||||
def prepare_dataloader(
|
||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
||||
):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
|
||||
@@ -60,11 +54,13 @@ class DPPluginBase(Plugin):
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
|
Reference in New Issue
Block a user