Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
replace the customized dataloader setup with the built-in one
@@ -4,22 +4,16 @@
Dataloader for sft, dpo, ppo
"""

import math
import os
import random
from dataclasses import dataclass
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union
from typing import Dict, Iterator, List, Optional, Sequence, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from coati.dataset.utils import chuncate_sequence, pad_to_max_len
from datasets import Dataset as HFDataset
from datasets import dataset_dict, load_from_disk
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer

DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
@@ -236,159 +230,26 @@ class DataCollatorForPreferenceDataset(object):


class StatefulDistributedSampler(DistributedSampler):
    """
    Stateful distributed sampler for multi-stage training.
    """

    def __init__(
        self,
        dataset: DatasetType,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        tp_size: int = 1,
        sp_size: int = 1,
        pp_size: int = 1,
    ) -> None:
        if not tp_size > 1:
            super().__init__(
                dataset=dataset,
                num_replicas=num_replicas,
                rank=rank,
                shuffle=shuffle,
                seed=seed,
                drop_last=drop_last,
            )
        else:
            # adapted from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L62
            if rank is None:
                rank = dist.get_rank()
            dist.get_world_size()
            # dp_size = world_size // (tp_size * sp_size * pp_size)
            dp_rank = int(rank / (tp_size * sp_size * pp_size))  # data parallel rank:
            if rank < 0:
                raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, 0]")
            self.dataset = dataset
            self.num_replicas = num_replicas
            self.dp_rank = dp_rank
            self.rank = rank
            self.epoch = 0
            self.drop_last = drop_last
            # If the dataset length is evenly divisible by # of replicas, then there
            # is no need to drop any data, since the dataset will be split equally.
            if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
                # Split to nearest available length that is evenly divisible.
                # This is to ensure each rank receives the same amount of data when
                # using this Sampler.
                self.num_samples = math.ceil(
                    (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
                )
            else:
                self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
            self.total_size = self.num_samples * self.num_replicas
            self.shuffle = shuffle
            self.seed = seed
            self.start_index = 0
            self.tp_size = tp_size
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.start_index: int = 0
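Aside (not part of the diff): the parallel-topology arguments handled by the removed branch above only matter for deriving a data-parallel rank from the global rank. A minimal standalone sketch of that arithmetic, with illustrative names:

# Standalone sketch, not from the repository: map a global rank to a
# data-parallel rank when tensor/sequence/pipeline parallelism share one world.
def data_parallel_rank(global_rank: int, tp_size: int, sp_size: int, pp_size: int) -> int:
    # consecutive blocks of tp_size * sp_size * pp_size ranks form one model replica
    return global_rank // (tp_size * sp_size * pp_size)

# e.g. tp_size=2, sp_size=1, pp_size=2: ranks 0-3 are replica 0, ranks 4-7 are replica 1
assert data_parallel_rank(5, tp_size=2, sp_size=1, pp_size=2) == 1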
    def __iter__(self) -> Iterator:
        if self.tp_size > 1:
            # TODO Add support for tp_group not equal to 1
            pass
            # adpated from https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/torch/utils/data/distributed.py#L96
            if self.shuffle:
                # deterministically shuffle based on epoch and seed
                g = torch.Generator()
                g.manual_seed(self.seed + self.epoch)
                indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
            else:
                indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

            if not self.drop_last:
                # add extra samples to make it evenly divisible
                padding_size = self.total_size - len(indices)
                if padding_size <= len(indices):
                    indices += indices[:padding_size]
                else:
                    indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
            else:
                # remove tail of data to make it evenly divisible.
                indices = indices[: self.total_size]
            assert len(indices) == self.total_size

            # subsample
            indices = indices[
                self.dp_rank : self.dp_rank + self.total_size : self.num_replicas
            ]  # num_replicas=tp_group=1, we only support tp_group==1 for now
            assert len(indices) == self.num_samples
            return iter(indices)

        else:
            iterator = super().__iter__()
            indices = list(iterator)
            indices = indices[self.start_index :]
            return iter(indices)
        iterator = super().__iter__()
        indices = list(iterator)
        indices = indices[self.start_index :]
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples - self.start_index

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index
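A minimal usage sketch of the simplified sampler above (not part of the diff; the toy dataset and resume offset are placeholders): set_start_index lets a resumed run skip the samples already consumed in the interrupted epoch, and __len__ shrinks accordingly.

# Sketch only: resuming mid-epoch with StatefulDistributedSampler.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100))
sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
sampler.set_start_index(40)  # pretend a checkpoint says 40 samples were already consumed
dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)

print(len(sampler))       # 60 samples remain in this epoch
for batch in dataloader:  # iteration starts at index 40
    pass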
def setup_distributed_dataloader(
    dataset: DatasetType,
    batch_size: int = 1,
    shuffle: bool = False,
    seed: int = 1024,
    drop_last: bool = False,
    pin_memory: bool = False,
    num_workers: int = 0,
    collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
    process_group: Optional[ProcessGroup] = None,
    tp_size: Optional[int] = 1,
    sp_size: Optional[int] = 1,
    pp_size: Optional[int] = 1,
    **kwargs,
) -> DataLoader:
    """
    Setup dataloader for distributed training.
    """
    _kwargs = kwargs.copy()
    process_group = process_group or _get_default_group()
    # world_size = tp_size * pp_size
    assert (
        process_group.size() % tp_size == 0
    ), f"process_group.size()={process_group.size()} must be divisible by tp_size={tp_size}"
    sampler = StatefulDistributedSampler(
        dataset=dataset,
        num_replicas=int(process_group.size() / tp_size),
        rank=process_group.rank(),
        shuffle=shuffle,
        seed=seed,
        drop_last=drop_last,
        tp_size=tp_size,
        sp_size=sp_size,
        pp_size=pp_size,
    )

    # Deterministic dataloader
    def seed_worker(worker_id: int) -> None:
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=drop_last,
        worker_init_fn=seed_worker,
        **_kwargs,
    )
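With the custom setup_distributed_dataloader removed, callers presumably switch to the framework's built-in dataloader helper, i.e. the prepare_dataloader method exposed by ColossalAI booster plugins, which wires up a stateful distributed sampler internally. A hedged sketch of what such a call might look like; the plugin choice, its arguments, and the dataset/collator placeholders are assumptions, not taken from this commit:

# Hedged sketch of the built-in replacement (names below are illustrative).
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
train_dataloader = plugin.prepare_dataloader(
    train_dataset,             # placeholder, e.g. the dataset returned by load_tokenized_dataset(...)
    batch_size=8,
    shuffle=True,
    drop_last=True,
    collate_fn=data_collator,  # placeholder, e.g. a DataCollatorForSupervisedDataset instance
)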