Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-14 05:33:23 +00:00
[llama] fix dataloader for hybrid parallel (#5358)
* [plugin] refactor prepare dataloader
* [plugin] update train script
@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
         return ["cuda", "npu"]
 
     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
         extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
         zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
         extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset,
             num_replicas=zero_world_size * extra_dp_world_size,
             rank=zero_rank * extra_dp_world_size + extra_dp_rank,
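For context, the diff above adds a `distributed_sampler_cls` hook: the plugin instantiates that class with `(dataset, num_replicas=..., rank=...)` and falls back to `torch.utils.data.DistributedSampler` when the argument is left as `None`. The sketch below illustrates how a caller could use the hook; it is not code from this commit, and `MyResumableSampler` as well as the `GeminiPlugin()` construction arguments are illustrative assumptions.

```python
# Minimal usage sketch for the new `distributed_sampler_cls` argument.
# Assumes the distributed environment is already initialized
# (e.g. via colossalai.launch_from_torch in a torchrun launch).
import torch
from torch.utils.data import DistributedSampler, TensorDataset

from colossalai.booster.plugin import GeminiPlugin


class MyResumableSampler(DistributedSampler):
    """Hypothetical sampler that can skip samples already consumed in an interrupted epoch."""

    def __init__(self, dataset, num_replicas=None, rank=None, **kwargs):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, **kwargs)
        self.start_index = 0  # set from a checkpoint before resuming

    def __iter__(self):
        indices = list(super().__iter__())
        return iter(indices[self.start_index:])

    def __len__(self):
        return self.num_samples - self.start_index


plugin = GeminiPlugin()  # constructor arguments omitted for brevity
dataset = TensorDataset(torch.arange(1024).float())

# The plugin passes (dataset, num_replicas=..., rank=...) to the class,
# so any DistributedSampler-compatible class can be plugged in here.
dataloader = plugin.prepare_dataloader(
    dataset,
    batch_size=32,
    shuffle=True,
    distributed_sampler_cls=MyResumableSampler,  # defaults to DistributedSampler when None
)
```

Because the class, rather than an instance, is passed in, the plugin keeps control of `num_replicas` and `rank`, which it derives from its ZeRO/extra-DP process-group mesh as shown in the second hunk.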