Mirror of https://github.com/hpcaitech/ColossalAI.git
[llama] fix dataloader for hybrid parallel (#5358)
* [plugin] refactor prepare dataloader
* [plugin] update train script
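The diff below adds a distributed_sampler_cls parameter to prepare_dataloader, falling back to the stock torch.utils.data.distributed.DistributedSampler when it is left as None. A minimal usage sketch of that fallback pattern from a caller's point of view; the sampler subclass name, prepare_dataloader_sketch helper, and the num_replicas=1, rank=0 arguments are illustrative assumptions for a single-process run, not code taken from this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


class MyDistributedSampler(DistributedSampler):
    """Illustrative stand-in: any class accepting DistributedSampler's
    (dataset, num_replicas=..., rank=..., shuffle=...) constructor works."""
    pass


def prepare_dataloader_sketch(dataset, batch_size, shuffle=False, distributed_sampler_cls=None, **kwargs):
    # Mirrors the fallback introduced in this diff: use the caller-supplied
    # sampler class if given, otherwise the default DistributedSampler.
    distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
    sampler = distributed_sampler_cls(dataset, num_replicas=1, rank=0, shuffle=shuffle)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, **kwargs)


dataset = TensorDataset(torch.arange(100))
loader = prepare_dataloader_sketch(dataset, batch_size=8, shuffle=True,
                                   distributed_sampler_cls=MyDistributedSampler)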
@@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
         self.world_size = dist.get_world_size()
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
 
         # Deterministic dataloader
         def seed_worker(worker_id):
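The second hunk is truncated right after the "# Deterministic dataloader" comment and the def seed_worker(worker_id): line. That helper is the usual PyTorch recipe for reproducible DataLoader workers; the sketch below shows the generic pattern, with the caveat that the repo's actual body is not visible in this diff and presumably derives the worker seed from the method's seed argument (default 1024) rather than from torch.initial_seed():

import random

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_worker(worker_id):
    # Derive a per-worker seed from the base seed torch assigned to this worker,
    # so numpy and random are deterministic across runs as well.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Typical wiring: pass the hook plus a seeded generator to the DataLoader,
# which is presumably how prepare_dataloader uses its seed argument.
generator = torch.Generator()
generator.manual_seed(1024)  # matches the seed=1024 default in the signature above
# loader = DataLoader(dataset, batch_size=..., sampler=sampler,
#                     worker_init_fn=seed_worker, generator=generator)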