Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[legacy] move builder and registry to legacy (#4603)
@@ -4,15 +4,15 @@
 import math
 import random
-import numpy as np
-from typing import TypeVar, Iterator
+from typing import Iterator, TypeVar

+import numpy as np
 import torch
-from torch.utils.data import Sampler, Dataset, DataLoader
+from torch.utils.data import DataLoader, Dataset, Sampler

 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import DATA_SAMPLERS
+from colossalai.legacy.registry import DATA_SAMPLERS

 T_co = TypeVar('T_co', covariant=True)
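The only behavioral edit in this hunk is the registry path: DATA_SAMPLERS moves from colossalai.registry to colossalai.legacy.registry. A minimal sketch of what that means for downstream code, assuming the registry's register_module decorator (the same mechanism this file applies to DataParallelSampler); MyToySampler is a placeholder, not part of ColossalAI:

```python
# Sketch: registering a custom sampler against the relocated legacy registry.
from torch.utils.data import Dataset, Sampler

from colossalai.legacy.registry import DATA_SAMPLERS  # was: colossalai.registry


@DATA_SAMPLERS.register_module
class MyToySampler(Sampler):
    """Hypothetical sampler, used only to illustrate registration."""

    def __init__(self, dataset: Dataset) -> None:
        self.dataset = dataset

    def __iter__(self):
        return iter(range(len(self.dataset)))

    def __len__(self):
        return len(self.dataset)
```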
@@ -30,11 +30,7 @@ class DataParallelSampler(Sampler):
             the batch size, then the last batch will be smaller, defaults to False.
     """

-    def __init__(self,
-                 dataset: Dataset,
-                 shuffle: bool = False,
-                 seed: int = 0,
-                 drop_last: bool = False) -> None:
+    def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_last: bool = False) -> None:
         self.dataset = dataset
         self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
         self.rank = gpc.get_local_rank(ParallelMode.DATA)
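The collapsed __init__ signature is formatting only; the body still derives the shard layout from the global context. A sketch of those two calls under an assumed 4-process data-parallel launch (gpc must already be initialized, e.g. via colossalai.launch):

```python
# Assumed setup: colossalai.launch(...) has created a DATA parallel group of size 4.
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

num_replicas = gpc.get_world_size(ParallelMode.DATA)    # e.g. 4
rank = gpc.get_local_rank(ParallelMode.DATA)            # 0..3, unique per process
# Each replica then samples a disjoint 1/num_replicas slice of the dataset.
```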
@@ -54,8 +50,7 @@ class DataParallelSampler(Sampler):
                 self.num_replicas    # type: ignore[arg-type]
             )
         else:
-            self.num_samples = math.ceil(
-                len(self.dataset) / self.num_replicas)    # type: ignore[arg-type]
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)    # type: ignore[arg-type]
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
         self.seed = seed
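The math.ceil call is merely joined onto one line. A worked sketch of the arithmetic with assumed sizes, to make the padding/trimming trade-off concrete:

```python
import math

# Assumed sizes: 10 samples shared by 4 data-parallel ranks.
dataset_len, num_replicas = 10, 4

num_samples = math.ceil(dataset_len / num_replicas)    # ceil(2.5) = 3 per rank
total_size = num_samples * num_replicas                # 12, so 2 indices get padded

# With drop_last=True the branch above this hunk trims the tail instead;
# assuming the usual DistributedSampler formula, ceil((10 - 4) / 4) = 2
# samples per rank, 8 in total, with 2 dropped.
```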
@@ -72,7 +67,7 @@ class DataParallelSampler(Sampler):
                 # set_epoch manually
                 self.epoch += 1
         else:
-            indices = list(range(len(self.dataset)))    # type: ignore[arg-type]
+            indices = list(range(len(self.dataset)))    # type: ignore[arg-type]

         if not self.drop_last:
             # add extra samples to make it evenly divisible
@@ -80,8 +75,7 @@ class DataParallelSampler(Sampler):
             if padding_size <= len(indices):
                 indices += indices[:padding_size]
             else:
-                indices += (indices * math.ceil(padding_size /
-                                                len(indices)))[:padding_size]
+                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
         else:
             # remove tail of data to make it evenly divisible.
             indices = indices[:self.total_size]
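The tiling expression is likewise only re-wrapped. A standalone sketch of the padding step with assumed sizes:

```python
import math

# Assumed: total_size = 12 but only 10 indices exist, so 2 are repeated.
indices = list(range(10))
total_size = 12

padding_size = total_size - len(indices)    # 2
if padding_size <= len(indices):
    indices += indices[:padding_size]       # -> [0, 1, ..., 9, 0, 1]
else:
    # Dataset smaller than the shortfall: tile it until it covers the gap.
    indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]

assert len(indices) == total_size
```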
@@ -109,8 +103,8 @@ class DataParallelSampler(Sampler):

 def get_dataloader(dataset,
                    shuffle=False,
-                   seed=1024,
-                   add_sampler=True,
+                   seed=1024,
+                   add_sampler=True,
                    drop_last=False,
                    pin_memory=False,
                    num_workers=0,
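The final hunk only re-indents two keyword defaults; the get_dataloader signature is otherwise unchanged and truncated here. A usage sketch limited to the parameters visible above, assuming the conventional colossalai.utils import path; the real helper also forwards further torch DataLoader kwargs:

```python
import torch
from torch.utils.data import TensorDataset

from colossalai.utils import get_dataloader    # assumed import path

dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))

# add_sampler=True attaches DataParallelSampler so each data-parallel rank
# sees a disjoint shard; this requires gpc to be initialized beforehand.
loader = get_dataloader(dataset,
                        shuffle=True,
                        seed=1024,
                        add_sampler=True,
                        drop_last=False,
                        pin_memory=True,
                        num_workers=2)
```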