# Mirror of https://github.com/hpcaitech/ColossalAI.git
# (synced 2026-04-25 01:03:35 +00:00; 71 lines, 2.1 KiB, Python)
import glob
|
|
import os
|
|
import colossalai
|
|
from colossalai.context import ParallelMode
|
|
from colossalai.core import global_context as gpc
|
|
from colossalai.logging import get_global_dist_logger
|
|
from colossalai.trainer import Trainer
|
|
from colossalai.utils import set_global_multitimer_status
|
|
from dataloader.imagenet_dali_dataloader import DaliDataloader
|
|
|
|
|
|
def build_dali_train():
    """Build the sharded DALI dataloader used for training.

    Data files are collected from ``<root>/train/*`` and index files from
    ``<root>/idx_files/train/*``, where ``root`` is ``gpc.config.dali.root``.
    Both lists are sorted and passed side by side. NOTE(review): this assumes
    DaliDataloader pairs them positionally (one index file per data file);
    confirm against DaliDataloader.

    Each data-parallel rank reads its own shard. ``shard_id`` is the rank's
    local rank in ``ParallelMode.DATA`` and ``num_shards`` is that group's
    world size.

    Returns:
        DaliDataloader: a loader built with ``training=True``, batch size
        ``gpc.config.BATCH_SIZE``, and ``gpu_aug`` / ``mixup_alpha`` taken
        from ``gpc.config.dali``.

    Raises:
        FileNotFoundError: if no files match ``<root>/train/*``.
        ValueError: if the number of index files differs from the number of
            data files.
    """
    root = gpc.config.dali.root
    train_pat = os.path.join(root, 'train/*')
    train_idx_pat = os.path.join(root, 'idx_files/train/*')

    # glob() returns matches in arbitrary order. Sort both lists so that
    # data and index files line up position by position.
    train_files = sorted(glob.glob(train_pat))
    train_idx_files = sorted(glob.glob(train_idx_pat))

    # Fail fast with an actionable message. Otherwise an empty or misaligned
    # file list only surfaces later as an opaque loader error.
    if not train_files:
        raise FileNotFoundError(
            f'No DALI training data files match {train_pat!r}')
    if len(train_idx_files) != len(train_files):
        raise ValueError(
            f'Found {len(train_files)} training data files ({train_pat!r}) '
            f'but {len(train_idx_files)} index files ({train_idx_pat!r}); '
            'expected one index file per data file')

    return DaliDataloader(
        train_files,
        train_idx_files,
        batch_size=gpc.config.BATCH_SIZE,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=True,
        gpu_aug=gpc.config.dali.gpu_aug,
        cuda=True,
        mixup_alpha=gpc.config.dali.mixup_alpha,
    )
|
|
|
|
|
|
def build_dali_test():
    """Build the sharded DALI dataloader used for validation.

    Data files are collected from ``<root>/validation/*`` and index files
    from ``<root>/idx_files/validation/*``, where ``root`` is
    ``gpc.config.dali.root``. Both lists are sorted and passed side by side.
    NOTE(review): this assumes DaliDataloader pairs them positionally (one
    index file per data file); confirm against DaliDataloader.

    Each data-parallel rank reads its own shard. ``shard_id`` is the rank's
    local rank in ``ParallelMode.DATA`` and ``num_shards`` is that group's
    world size.

    Returns:
        DaliDataloader: a loader built with ``training=False``, batch size
        ``gpc.config.BATCH_SIZE``, and GPU augmentation turned off.

    Raises:
        FileNotFoundError: if no files match ``<root>/validation/*``.
        ValueError: if the number of index files differs from the number of
            data files.
    """
    root = gpc.config.dali.root
    val_pat = os.path.join(root, 'validation/*')
    val_idx_pat = os.path.join(root, 'idx_files/validation/*')

    # glob() returns matches in arbitrary order. Sort both lists so that
    # data and index files line up position by position.
    val_files = sorted(glob.glob(val_pat))
    val_idx_files = sorted(glob.glob(val_idx_pat))

    # Fail fast with an actionable message. Otherwise an empty or misaligned
    # file list only surfaces later as an opaque loader error.
    if not val_files:
        raise FileNotFoundError(
            f'No DALI validation data files match {val_pat!r}')
    if len(val_idx_files) != len(val_files):
        raise ValueError(
            f'Found {len(val_files)} validation data files ({val_pat!r}) '
            f'but {len(val_idx_files)} index files ({val_idx_pat!r}); '
            'expected one index file per data file')

    return DaliDataloader(
        val_files,
        val_idx_files,
        batch_size=gpc.config.BATCH_SIZE,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=False,
        # Hard-coded off for validation instead of following
        # gpc.config.dali.gpu_aug as the training loader does.
        gpu_aug=False,
        cuda=True,
        # NOTE(review): mixup_alpha is forwarded even though training=False;
        # presumably DaliDataloader ignores it outside training. Confirm.
        mixup_alpha=gpc.config.dali.mixup_alpha,
    )
|
|
|
|
|
|
def main():
    """Script entry point: initialise ColossalAI with the DALI loaders, then train.

    The two builder functions (not loader instances) are passed to
    ``colossalai.initialize``. It returns the engine together with the
    train/test dataloaders.

    Training then runs for ``gpc.config.NUM_EPOCHS`` epochs with the hook
    configuration from ``gpc.config.hooks``. The global multi-timer is
    enabled and progress is displayed. ``test_interval=1`` presumably means
    evaluation after every epoch.
    """
    engine, train_loader, test_loader = colossalai.initialize(
        train_dataloader=build_dali_train,
        test_dataloader=build_dali_test,
    )
    # Not referenced below; the call is kept in case it has side effects.
    logger = get_global_dist_logger()

    # Turn the global multi-timer on before handing it to the trainer.
    set_global_multitimer_status(True)
    multitimer = colossalai.utils.get_global_multitimer()

    runner = Trainer(engine=engine, verbose=True, timer=multitimer)
    fit_options = dict(
        epochs=gpc.config.NUM_EPOCHS,
        hooks_cfg=gpc.config.hooks,
        display_progress=True,
        test_interval=1,
    )
    runner.fit(
        train_dataloader=train_loader,
        test_dataloader=test_loader,
        **fit_options,
    )
|
|
|
|
|
|
# Start training only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()
|