[llama] fix dataloader for hybrid parallel (#5358)

* [plugin] refactor prepare dataloader

* [plugin] update train script
Author: Hongxin Liu (committed by GitHub)
Date: 2024-02-05 15:14:56 +08:00
Parent: 2dd01e3a14
Commit: 6c0fa7b9a8
6 changed files with 45 additions and 65 deletions


@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
         return ["cuda", "npu"]
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
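
The first hunk adds a distributed_sampler_cls keyword to prepare_dataloader, so callers can inject their own sampler class instead of the hard-coded torch.utils.data.DistributedSampler. Below is a minimal usage sketch; the SkipFirstKSampler subclass, the toy dataset, and the bare GeminiPlugin() construction are illustrative assumptions (a real script would first initialize the distributed environment, e.g. with colossalai.launch_from_torch), not part of this commit.

# Usage sketch (not from this commit): pass a custom sampler class through the new
# distributed_sampler_cls keyword. Assumes the process group has already been set up.
import torch
from torch.utils.data import DistributedSampler, TensorDataset

from colossalai.booster.plugin import GeminiPlugin


class SkipFirstKSampler(DistributedSampler):
    """Illustrative sampler that skips the first k indices, e.g. to resume mid-epoch."""

    def __init__(self, *args, k=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k

    def __iter__(self):
        return iter(list(super().__iter__())[self.k:])


train_ds = TensorDataset(torch.arange(1000, dtype=torch.float32))
plugin = GeminiPlugin()
train_loader = plugin.prepare_dataloader(
    train_ds,
    batch_size=8,
    shuffle=True,
    # The plugin instantiates this class itself (roughly cls(dataset, num_replicas=..., rank=...));
    # when left as None it falls back to DistributedSampler, as the second hunk shows.
    distributed_sampler_cls=SkipFirstKSampler,
)
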
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
         extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
         zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
         extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset,
             num_replicas=zero_world_size * extra_dp_world_size,
             rank=zero_rank * extra_dp_world_size + extra_dp_rank,
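
The second hunk keeps the hybrid-parallel sharding arithmetic intact while making the sampler class pluggable: the ZeRO axis and the extra data-parallel axis of the process-group mesh are flattened into zero_world_size * extra_dp_world_size sampler replicas, and each worker's flat rank is zero_rank * extra_dp_world_size + extra_dp_rank. A small standalone sketch of that mapping follows; the mesh sizes are made-up examples, not taken from the commit.

# Standalone sketch of the replica/rank flattening used above; the sizes are invented.
zero_world_size, extra_dp_world_size = 4, 2           # sizes of ZERO_AXIS and DP_AXIS
num_replicas = zero_world_size * extra_dp_world_size  # 8 sampler replicas in total

# Every (zero_rank, extra_dp_rank) coordinate maps to a unique flat rank in
# [0, num_replicas), so each data-parallel worker draws a disjoint shard of the dataset.
flat_rank = {
    (z, d): z * extra_dp_world_size + d
    for z in range(zero_world_size)
    for d in range(extra_dp_world_size)
}
assert sorted(flat_rank.values()) == list(range(num_replicas))

Because only the sampler class is swapped and the num_replicas/rank arguments are unchanged, any drop-in DistributedSampler subclass inherits the correct hybrid-parallel sharding automatically.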