[llama] fix dataloader for hybrid parallel (#5358)

* [plugin] refactor prepare dataloader

* [plugin] update train script
Author: Hongxin Liu
Date: 2024-02-05 15:14:56 +08:00
Committed by: GitHub
parent 2dd01e3a14
commit 6c0fa7b9a8
6 changed files with 45 additions and 65 deletions

```diff
@@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
         return outputs
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
         )
```
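
The net effect of the patch is that callers can now inject their own sampler class through `distributed_sampler_cls`, with `torch.utils.data.distributed.DistributedSampler` remaining the default when the argument is left as `None`. The sketch below illustrates the same fallback pattern in isolation; `EpochLoggingSampler`, `build_sampler`, and the toy dataset are hypothetical stand-ins for illustration only, not part of ColossalAI.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


class EpochLoggingSampler(DistributedSampler):
    """Hypothetical sampler: any class accepting the DistributedSampler
    constructor arguments (dataset, num_replicas=..., rank=..., shuffle=...)
    can be passed as ``distributed_sampler_cls``."""

    def set_epoch(self, epoch: int) -> None:
        print(f"[sampler] starting epoch {epoch}")
        super().set_epoch(epoch)


def build_sampler(dataset, num_replicas, rank, shuffle, distributed_sampler_cls=None):
    # Same fallback as the patch: use the caller-supplied class, else the default.
    distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
    return distributed_sampler_cls(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)


if __name__ == "__main__":
    # Toy dataset; num_replicas/rank are given explicitly, so no process group
    # needs to be initialized for this illustration.
    dataset = TensorDataset(torch.arange(16, dtype=torch.float32))
    sampler = build_sampler(
        dataset, num_replicas=2, rank=0, shuffle=True, distributed_sampler_cls=EpochLoggingSampler
    )
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)
    sampler.set_epoch(0)
    print([batch[0].tolist() for batch in loader])
```

Against the plugin itself, the equivalent call would look like `plugin.prepare_dataloader(dataset, batch_size=4, shuffle=True, distributed_sampler_cls=EpochLoggingSampler)`, with `num_replicas` and `rank` supplied internally from the data-parallel axis of the plugin's process group mesh, as the second hunk shows.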