mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-27 15:57:16 +00:00
[example] update vit ci script (#2469)
* [example] update vit ci script * [example] update requirements * [example] update requirements
This commit is contained in:
parent
867c8c2d3a
commit
8e85d2440a
32
examples/images/vit/configs/vit_1d_tp2_ci.py
Normal file
32
examples/images/vit/configs/vit_1d_tp2_ci.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from colossalai.amp import AMP_TYPE
|
||||||
|
|
||||||
|
# hyperparameters
|
||||||
|
# BATCH_SIZE is as per GPU
|
||||||
|
# global batch size = BATCH_SIZE x data parallel size
|
||||||
|
BATCH_SIZE = 8
|
||||||
|
LEARNING_RATE = 3e-3
|
||||||
|
WEIGHT_DECAY = 0.3
|
||||||
|
NUM_EPOCHS = 3
|
||||||
|
WARMUP_EPOCHS = 1
|
||||||
|
|
||||||
|
# model config
|
||||||
|
IMG_SIZE = 224
|
||||||
|
PATCH_SIZE = 16
|
||||||
|
HIDDEN_SIZE = 32
|
||||||
|
DEPTH = 2
|
||||||
|
NUM_HEADS = 4
|
||||||
|
MLP_RATIO = 4
|
||||||
|
NUM_CLASSES = 10
|
||||||
|
CHECKPOINT = False
|
||||||
|
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
|
||||||
|
|
||||||
|
USE_DDP = True
|
||||||
|
TP_WORLD_SIZE = 2
|
||||||
|
TP_TYPE = 'row'
|
||||||
|
parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)
|
||||||
|
|
||||||
|
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||||
|
clip_grad_norm = 1.0
|
||||||
|
gradient_accumulation = 2
|
||||||
|
|
||||||
|
LOG_PATH = "./log_ci"
|
@ -1,2 +1,8 @@
|
|||||||
colossalai >= 0.1.12
|
colossalai >= 0.1.12
|
||||||
torch >= 1.8.1
|
torch >= 1.8.1
|
||||||
|
numpy>=1.24.1
|
||||||
|
timm>=0.6.12
|
||||||
|
titans>=0.0.7
|
||||||
|
tqdm>=4.61.2
|
||||||
|
transformers>=4.25.1
|
||||||
|
nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist
|
||||||
|
9
examples/images/vit/test_ci.sh
Normal file
9
examples/images/vit/test_ci.sh
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
export OMP_NUM_THREADS=4
|
||||||
|
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# train
|
||||||
|
colossalai run \
|
||||||
|
--nproc_per_node 4 train.py \
|
||||||
|
--config configs/vit_1d_tp2_ci.py \
|
||||||
|
--dummy_data
|
@ -7,6 +7,7 @@ import torch.nn.functional as F
|
|||||||
from timm.models.vision_transformer import _create_vision_transformer
|
from timm.models.vision_transformer import _create_vision_transformer
|
||||||
from titans.dataloader.imagenet import build_dali_imagenet
|
from titans.dataloader.imagenet import build_dali_imagenet
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from vit import DummyDataLoader
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
@ -56,8 +57,8 @@ def init_spec_func(model, tp_type):
|
|||||||
def train_imagenet():
|
def train_imagenet():
|
||||||
|
|
||||||
parser = colossalai.get_default_parser()
|
parser = colossalai.get_default_parser()
|
||||||
parser.add_argument('--from_torch', default=True, action='store_true')
|
parser.add_argument('--resume_from', default=False, action='store_true')
|
||||||
parser.add_argument('--resume_from', default=False)
|
parser.add_argument('--dummy_data', default=False, action='store_true')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
colossalai.launch_from_torch(config=args.config)
|
colossalai.launch_from_torch(config=args.config)
|
||||||
@ -74,10 +75,22 @@ def train_imagenet():
|
|||||||
logger.log_to_file(log_path)
|
logger.log_to_file(log_path)
|
||||||
|
|
||||||
logger.info('Build data loader', ranks=[0])
|
logger.info('Build data loader', ranks=[0])
|
||||||
root = os.environ['DATA']
|
if not args.dummy_data:
|
||||||
train_dataloader, test_dataloader = build_dali_imagenet(root,
|
root = os.environ['DATA']
|
||||||
train_batch_size=gpc.config.BATCH_SIZE,
|
train_dataloader, test_dataloader = build_dali_imagenet(root,
|
||||||
test_batch_size=gpc.config.BATCH_SIZE)
|
train_batch_size=gpc.config.BATCH_SIZE,
|
||||||
|
test_batch_size=gpc.config.BATCH_SIZE)
|
||||||
|
else:
|
||||||
|
train_dataloader = DummyDataLoader(length=10,
|
||||||
|
batch_size=gpc.config.BATCH_SIZE,
|
||||||
|
category=gpc.config.NUM_CLASSES,
|
||||||
|
image_size=gpc.config.IMG_SIZE,
|
||||||
|
return_dict=False)
|
||||||
|
test_dataloader = DummyDataLoader(length=5,
|
||||||
|
batch_size=gpc.config.BATCH_SIZE,
|
||||||
|
category=gpc.config.NUM_CLASSES,
|
||||||
|
image_size=gpc.config.IMG_SIZE,
|
||||||
|
return_dict=False)
|
||||||
|
|
||||||
logger.info('Build model', ranks=[0])
|
logger.info('Build model', ranks=[0])
|
||||||
|
|
||||||
|
@ -32,21 +32,24 @@ class DummyDataGenerator(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class DummyDataLoader(DummyDataGenerator):
|
class DummyDataLoader(DummyDataGenerator):
|
||||||
batch_size = 4
|
|
||||||
channel = 3
|
def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True):
|
||||||
category = 8
|
super().__init__(length)
|
||||||
image_size = 224
|
self.batch_size = batch_size
|
||||||
|
self.channel = channel
|
||||||
|
self.category = category
|
||||||
|
self.image_size = image_size
|
||||||
|
self.return_dict = return_dict
|
||||||
|
|
||||||
def generate(self):
|
def generate(self):
|
||||||
image_dict = {}
|
image_dict = {}
|
||||||
image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
|
image_dict['pixel_values'] = torch.rand(
|
||||||
DummyDataLoader.channel,
|
self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1
|
||||||
DummyDataLoader.image_size,
|
image_dict['label'] = torch.randint(self.category, (self.batch_size,),
|
||||||
DummyDataLoader.image_size,
|
|
||||||
device=get_current_device()) * 2 - 1
|
|
||||||
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
|
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=get_current_device())
|
device=get_current_device())
|
||||||
|
if not self.return_dict:
|
||||||
|
return image_dict['pixel_values'], image_dict['label']
|
||||||
return image_dict
|
return image_dict
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user