mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-29 12:15:39 +00:00
* benchmark gpt2 * fix fix fix fix * [doc] fix typo in Colossal-LLaMA-2/README.md (#5247) * [workflow] fixed build CI (#5240) * [workflow] fixed build CI * polish * polish * polish * polish * polish * [ci] fixed booster test (#5251) * [ci] fixed booster test * [ci] fixed booster test * [ci] fixed booster test * [ci] fixed ddp test (#5254) * [ci] fixed ddp test * polish * fix typo in applications/ColossalEval/README.md (#5250) * [ci] fix shardformer tests. (#5255) * fix ci fix * revert: revert p2p * feat: add enable_metadata_cache option * revert: enable t5 tests --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [doc] fix doc typo (#5256) * [doc] fix annotation display * [doc] fix llama2 doc * [hotfix]: add pp sanity check and fix mbs arg (#5268) * fix: fix misleading mbs arg * feat: add pp sanity check * fix: fix 1f1b sanity check * [workflow] fixed incomplete bash command (#5272) * [workflow] fixed oom tests (#5275) * [workflow] fixed oom tests * polish * polish * polish * [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276) * fix ci fix * fix test * revert: revert p2p * feat: add enable_metadata_cache option * revert: enable t5 tests * fix --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [shardformer] hybridparallelplugin support gradients accumulation. 
(#5246) * support gradients acc fix fix fix fix fix fix fix fix fix fix fix fix fix * fix fix * fix fix fix * [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230) * fix auto loading gpt2 tokenizer (#5279) * [doc] add llama2-13B disyplay (#5285) * Update README.md * fix 13b typo --------- Co-authored-by: binmakeswell <binmakeswell@gmail.com> * fix llama pretrain (#5287) * fix * fix * fix fix * fix fix fix * fix fix * benchmark gpt2 * fix fix fix fix * [workflow] fixed build CI (#5240) * [workflow] fixed build CI * polish * polish * polish * polish * polish * [ci] fixed booster test (#5251) * [ci] fixed booster test * [ci] fixed booster test * [ci] fixed booster test * fix fix * fix fix fix * fix * fix fix fix fix fix * fix * Update shardformer.py --------- Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com> Co-authored-by: Desperado-Jia <502205863@qq.com>
124 lines
4.1 KiB
Python
124 lines
4.1 KiB
Python
import json
|
|
import random
|
|
from typing import Iterator, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.distributed import ProcessGroup
|
|
from torch.distributed.distributed_c10d import _get_default_group
|
|
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
|
|
|
from colossalai.accelerator import get_accelerator
|
|
|
|
|
|
class StatefulDistributedSampler(DistributedSampler):
    """A ``DistributedSampler`` that can skip a prefix of this replica's indices.

    After ``set_start_index(k)``, iteration yields this replica's shard of indices
    (as produced by ``DistributedSampler.__iter__``) with the first ``k`` entries
    dropped, and ``len()`` reports the number of remaining indices. This is
    typically used to resume part-way through an epoch.

    Note: ``start_index`` is not reset by iteration; it stays in effect for every
    subsequent epoch until it is changed again via ``set_start_index``.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        # Number of leading indices of this replica's shard to skip when iterating.
        self.start_index: int = 0

    def __iter__(self) -> Iterator:
        # Materialize the parent's per-replica index list so a prefix can be dropped.
        indices = list(super().__iter__())
        return iter(indices[self.start_index :])

    def __len__(self) -> int:
        # Clamp at zero: once start_index >= num_samples, __iter__ yields nothing,
        # and len() raises ValueError if __len__ returns a negative number.
        return max(self.num_samples - self.start_index, 0)

    def set_start_index(self, start_index: int) -> None:
        """Skip the first ``start_index`` indices of this replica's shard from now on.

        Args:
            start_index (int): Number of leading indices to skip; must be >= 0.

        Raises:
            ValueError: If ``start_index`` is negative.
        """
        if start_index < 0:
            raise ValueError(f"start_index must be non-negative, got {start_index}")
        self.start_index = start_index
|
|
|
|
|
|
def prepare_dataloader(
    dataset,
    batch_size,
    shuffle=False,
    seed=1024,
    drop_last=False,
    pin_memory=False,
    num_workers=0,
    process_group: Optional[ProcessGroup] = None,
    **kwargs,
):
    r"""
    Prepare a dataloader for distributed training. The dataloader will be wrapped by
    `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.

    Args:
        dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
        batch_size (int): Number of samples per batch on each rank.
        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
        seed (int, optional): Seed for the sampler's shuffling order and for seeding
            ``random``, ``numpy`` and ``torch`` in every dataloader worker. Defaults to 1024.
        drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
            is not divisible by the batch size. If False and the size of dataset is not divisible by
            the batch size, then the last batch will be smaller, defaults to False.
        pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
        num_workers (int, optional): Number of worker subprocesses for this dataloader. Defaults to 0.
        process_group (`torch.distributed.ProcessGroup`, optional): Group whose ``size()`` and
            ``rank()`` determine how the dataset is sharded. Defaults to the default process group.
        kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
            `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.

    Returns:
        :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
    """
    process_group = process_group or _get_default_group()
    # Forward `seed` so the shuffle order actually depends on it; without it the
    # sampler silently falls back to DistributedSampler's default seed of 0.
    sampler = StatefulDistributedSampler(
        dataset,
        num_replicas=process_group.size(),
        rank=process_group.rank(),
        shuffle=shuffle,
        seed=seed,
    )

    # Deterministic dataloader: every worker seeds the Python, NumPy and torch RNGs.
    # NOTE(review): all workers receive the same seed (worker_id is ignored), so
    # per-worker random augmentations will be identical across workers.
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        worker_init_fn=seed_worker,
        drop_last=drop_last,
        pin_memory=pin_memory,
        num_workers=num_workers,
        **kwargs,
    )
|
|
|
|
|
|
def load_json(file_path: str):
    """Read and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 explicitly instead of with the platform's
    locale-dependent default encoding.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
def save_json(data, file_path: str):
    """Serialize ``data`` as JSON (4-space indent) to ``file_path``, overwriting it.

    The file is opened with an explicit UTF-8 encoding instead of the platform's
    locale-dependent default. With ``json.dump``'s default ``ensure_ascii=True``
    the written bytes are pure ASCII, so the output is unchanged.
    """
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
|
|
|
|
|
|
class RandomDataset(Dataset):
    """Synthetic token dataset filled with uniformly random ids.

    All samples are generated eagerly at construction as one
    ``(num_samples, max_length)`` integer tensor with values in ``[0, vocab_size)``,
    placed on the accelerator's current device. The attention mask is all ones, and
    each sample's labels are the same token ids as its inputs.
    """

    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        target_device = get_accelerator().get_current_device()
        token_shape = (num_samples, max_length)
        self.input_ids = torch.randint(0, vocab_size, token_shape, device=target_device)
        # Every position is attended to; the mask mirrors input_ids' shape/device.
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        sample = dict(
            input_ids=self.input_ids[idx],
            attention_mask=self.attention_mask[idx],
            labels=self.input_ids[idx],
        )
        return sample