# Files
# ColossalAI/examples/vit-b16/train_dali.py
#
# 71 lines
# 2.1 KiB
# Python

import glob
import os
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import set_global_multitimer_status
from dataloader.imagenet_dali_dataloader import DaliDataloader
def build_dali_train():
    """Build the DALI dataloader for the training split.

    Data files are globbed from ``<root>/train/*`` and index files from
    ``<root>/idx_files/train/*``, where ``root`` is ``gpc.config.dali.root``.
    Both file lists are sorted before being passed on.

    Returns:
        The object produced by ``DaliDataloader(...)`` with ``training=True``,
        using the local rank and world size of ``ParallelMode.DATA`` as the
        shard id and shard count.
    """
    dali_cfg = gpc.config.dali
    data_root = dali_cfg.root

    # glob order is not guaranteed, so both lists are sorted explicitly.
    data_files = glob.glob(os.path.join(data_root, 'train/*'))
    data_files.sort()
    index_files = glob.glob(os.path.join(data_root, 'idx_files/train/*'))
    index_files.sort()

    loader_options = dict(
        batch_size=gpc.config.BATCH_SIZE,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=True,
        gpu_aug=dali_cfg.gpu_aug,
        cuda=True,
        mixup_alpha=dali_cfg.mixup_alpha,
    )
    return DaliDataloader(data_files, index_files, **loader_options)
def build_dali_test():
    """Build the DALI dataloader for the validation split.

    Data files are globbed from ``<root>/validation/*`` and index files from
    ``<root>/idx_files/validation/*``, where ``root`` is
    ``gpc.config.dali.root``. Both file lists are sorted before being passed on.

    Returns:
        The object produced by ``DaliDataloader(...)`` with ``training=False``,
        using the local rank and world size of ``ParallelMode.DATA`` as the
        shard id and shard count.
    """
    dali_cfg = gpc.config.dali
    data_root = dali_cfg.root

    # glob order is not guaranteed, so both lists are sorted explicitly.
    data_files = glob.glob(os.path.join(data_root, 'validation/*'))
    data_files.sort()
    index_files = glob.glob(os.path.join(data_root, 'idx_files/validation/*'))
    index_files.sort()

    loader_options = dict(
        batch_size=gpc.config.BATCH_SIZE,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=False,
        # GPU augmentation is forced off here; gpc.config.dali.gpu_aug is
        # deliberately not consulted for this loader.
        gpu_aug=False,
        cuda=True,
        mixup_alpha=dali_cfg.mixup_alpha,
    )
    return DaliDataloader(data_files, index_files, **loader_options)
def main():
    """Initialize ColossalAI, then train with the DALI dataloaders.

    The two builder functions are handed to ``colossalai.initialize``, the
    global multitimer is switched on, and a ``Trainer`` is fitted for
    ``gpc.config.NUM_EPOCHS`` epochs using the hooks in ``gpc.config.hooks``.
    """
    # The builder callables themselves are passed, not their results.
    engine, train_loader, test_loader = colossalai.initialize(
        train_dataloader=build_dali_train,
        test_dataloader=build_dali_test,
    )

    # NOTE(review): the logger is not used further down; the call is kept
    # as-is in case it has side effects -- confirm before removing.
    logger = get_global_dist_logger()

    # Enable the global multitimer before fetching it for the trainer.
    set_global_multitimer_status(True)
    trainer = Trainer(
        engine=engine,
        verbose=True,
        timer=colossalai.utils.get_global_multitimer(),
    )

    fit_options = {
        'train_dataloader': train_loader,
        'test_dataloader': test_loader,
        'epochs': gpc.config.NUM_EPOCHS,
        'hooks_cfg': gpc.config.hooks,
        'display_progress': True,
        'test_interval': 1,
    }
    trainer.fit(**fit_options)
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()