Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-01 15:18:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 """Dataloaders."""
 
 import random
-
 import torch
@@ -22,61 +21,60 @@ from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 
 
-def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0):
+def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type="single", num_workers=0):
     """Build dataloader given an input dataset."""
 
     if dataset is None:
         return None
 
     # Megatron sampler
-    if dataloader_type == 'single':
-        batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset),
-                                                   consumed_samples=consumed_samples,
-                                                   micro_batch_size=micro_batch_size,
-                                                   data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
-                                                   data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
-    elif dataloader_type == 'cyclic':
-        batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset),
-                                                         consumed_samples=consumed_samples,
-                                                         micro_batch_size=micro_batch_size,
-                                                         data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
-                                                         data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
+    if dataloader_type == "single":
+        batch_sampler = MegatronPretrainingSampler(
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=micro_batch_size,
+            data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
+            data_parallel_size=gpc.get_world_size(ParallelMode.DATA),
+        )
+    elif dataloader_type == "cyclic":
+        batch_sampler = MegatronPretrainingRandomSampler(
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=micro_batch_size,
+            data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
+            data_parallel_size=gpc.get_world_size(ParallelMode.DATA),
+        )
     else:
-        raise Exception('{} dataloader type is not supported.'.format(dataloader_type))
+        raise Exception("{} dataloader type is not supported.".format(dataloader_type))
 
     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True)
 
 
 class MegatronPretrainingSampler:
-
-    def __init__(self,
-                 total_samples,
-                 consumed_samples,
-                 micro_batch_size,
-                 data_parallel_rank,
-                 data_parallel_size,
-                 drop_last=True):
+    def __init__(
+        self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True
+    ):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
         self.micro_batch_size = micro_batch_size
         self.data_parallel_rank = data_parallel_rank
-        self.micro_batch_times_data_parallel_size = \
-            self.micro_batch_size * data_parallel_size
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
         self.drop_last = drop_last
 
         # Sanity checks.
-        assert self.total_samples > 0, \
-            'no sample to consume: {}'.format(self.total_samples)
-        assert self.consumed_samples < self.total_samples, \
-            'no samples left to consume: {}, {}'.format(self.consumed_samples,
-                                                        self.total_samples)
+        assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
+        assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format(
+            self.consumed_samples, self.total_samples
+        )
         assert self.micro_batch_size > 0
         assert data_parallel_size > 0
-        assert self.data_parallel_rank < data_parallel_size, \
-            'data_parallel_rank should be smaller than data size: {}, ' \
-            '{}'.format(self.data_parallel_rank, data_parallel_size)
+        assert (
+            self.data_parallel_rank < data_parallel_size
+        ), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
+            self.data_parallel_rank, data_parallel_size
+        )
 
     def __len__(self):
         return self.total_samples
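For orientation, a minimal usage sketch of the builder above. The dataset and sample counts are made-up values, and the call assumes ColossalAI's global context (gpc) has already been initialized, e.g. via colossalai.legacy.launch; only build_pretraining_data_loader and its signature come from this diff.

    import torch

    # Hypothetical toy dataset; any map-style dataset with __len__ works.
    dataset = torch.utils.data.TensorDataset(torch.arange(1024))

    # Resume mid-epoch: pretend 256 samples were consumed before a restart.
    loader = build_pretraining_data_loader(dataset, consumed_samples=256, micro_batch_size=8)
    batch = next(iter(loader))  # first micro-batch after the consumed prefix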
@@ -103,7 +101,6 @@ class MegatronPretrainingSampler:
 
 
 class MegatronPretrainingRandomSampler:
-
     def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
@@ -111,19 +108,18 @@ class MegatronPretrainingRandomSampler:
         self.micro_batch_size = micro_batch_size
         self.data_parallel_rank = data_parallel_rank
         self.data_parallel_size = data_parallel_size
-        self.micro_batch_times_data_parallel_size = \
-            self.micro_batch_size * data_parallel_size
-        self.last_batch_size = \
-            self.total_samples % self.micro_batch_times_data_parallel_size
+        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
+        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
 
         # Sanity checks.
-        assert self.total_samples > 0, \
-            'no sample to consume: {}'.format(self.total_samples)
+        assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
         assert self.micro_batch_size > 0
         assert data_parallel_size > 0
-        assert self.data_parallel_rank < data_parallel_size, \
-            'data_parallel_rank should be smaller than data size: {}, ' \
-            '{}'.format(self.data_parallel_rank, data_parallel_size)
+        assert (
+            self.data_parallel_rank < data_parallel_size
+        ), "data_parallel_rank should be smaller than data size: {}, " "{}".format(
+            self.data_parallel_rank, data_parallel_size
+        )
 
     def __len__(self):
         return self.total_samples
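A quick arithmetic check of the two derived attributes, with illustrative numbers that are not from the source:

    # With 1000 samples and micro-batches of 4 on 8 data-parallel ranks,
    # each global step consumes 4 * 8 = 32 samples, and an epoch ends with
    # a partial batch of 1000 % 32 = 8 samples.
    total_samples, micro_batch_size, data_parallel_size = 1000, 4, 8
    micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size  # 32
    last_batch_size = total_samples % micro_batch_times_data_parallel_size  # 8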
@@ -135,8 +131,7 @@ class MegatronPretrainingRandomSampler:
         assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
 
         # data sharding and random sampling
-        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-            * self.micro_batch_size
+        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
         bucket_offset = current_epoch_samples // self.data_parallel_size
         start_idx = self.data_parallel_rank * bucket_size
 
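To make the sharding arithmetic concrete, a hedged sketch with the same illustrative numbers: each data-parallel rank owns a contiguous bucket of indices, offset by the samples already consumed this epoch. The in-bucket shuffle at the end is an assumption about the elided remainder of __iter__ (Megatron-style samplers seed a randperm per epoch), not code from this diff.

    import torch

    total_samples, micro_batch_size, data_parallel_size = 1000, 4, 8
    data_parallel_rank = 3  # hypothetical rank
    micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size
    current_epoch_samples = 64  # divisible by 32, as the assertion above requires

    bucket_size = (total_samples // micro_batch_times_data_parallel_size) * micro_batch_size  # 31 * 4 = 124
    bucket_offset = current_epoch_samples // data_parallel_size  # 64 // 8 = 8
    start_idx = data_parallel_rank * bucket_size  # 3 * 124 = 372

    # Assumed continuation: shuffle within the bucket, then skip the
    # already-consumed prefix so resumption is deterministic.
    g = torch.Generator()
    g.manual_seed(0)  # in practice the epoch number would seed this
    random_idx = torch.randperm(bucket_size, generator=g).tolist()
    idx_range = [start_idx + x for x in random_idx[bucket_offset:]]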