Hotfix/Colossalai layers (#92)

* optimized the 1D layer APIs; reorganized the nn.layer modules; fixed the affected tests

* fixed a 2.5D runtime issue

* reworked batch splitting; it is now called in trainer.schedule.load_batch
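
The batch-split rework in the last bullet moves the sharding into the schedule: each batch is split once, when it is loaded, instead of every parallel layer or loss splitting its own inputs. A minimal sketch of the idea, with split_batch stubbed out and the method body assumed for illustration (this is not the actual ColossalAI implementation):

def split_batch(tensor):
    # Stand-in for the real splitter: shard along the batch dimension across
    # the tensor-parallel ranks. Returned whole here to keep the sketch runnable.
    return tensor

class BaseSchedule:  # class name assumed for illustration
    def load_batch(self, data_iter):
        # Fetch the next (data, label) pair from the loader as before.
        data, label = next(data_iter)
        # New behavior: shard the batch here for whichever tensor-parallel
        # mode the config declares ('1d', '2d', '2.5d', '3d'), so layers,
        # losses, and metrics no longer take a tensor_parallel argument.
        return split_batch(data), split_batch(label)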

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Author: アマデウス
Date: 2021-12-29 23:32:10 +08:00
Committed by: GitHub
Parent: 0fedef4f3c
Commit: 01a80cd86d
71 changed files with 1033 additions and 773 deletions

@@ -2,7 +2,7 @@ BATCH_SIZE = 512
 LEARNING_RATE = 2e-3
 WEIGHT_DECAY = 3e-2
-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'
 NUM_EPOCHS = 200
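
After this hunk the whole CIFAR benchmark config is just the module-level constants shown. A plausible full file, where the parallel dict at the end is an assumption added to show where the two tensor-parallel constants typically end up (the layers now read this via the global context instead of a tensor_parallel keyword):

BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 200

# Assumed wiring: ColossalAI configs commonly declare tensor parallelism
# like this, giving gpc.config.parallel.tensor.size and .mode.
parallel = dict(
    tensor=dict(size=TENSOR_PARALLEL_SIZE, mode=TENSOR_PARALLEL_MODE),
)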

@@ -72,13 +72,11 @@ def train_cifar():
         os.mkdir(log_path)
     logger.log_to_file(log_path)
-    tp = gpc.config.parallel.tensor.mode
-    model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
+    model = vit_lite_depth7_patch4_32()
     train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
-    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+    criterion = CrossEntropyLoss(label_smoothing=0.1)
     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
@@ -107,7 +105,7 @@ def train_cifar():
         LogMetricByStepHook(),
         # LogTimingByEpochHook(timer=timer, logger=logger),
         # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+        AccuracyHook(accuracy_func=Accuracy()),
         LossHook(),
         ThroughputHook(),
         LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
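
Taken together, the two hunks above are the entire user-facing API change: the tensor_parallel keyword disappears, and the mode is resolved from the global context inside the library. Condensed before/after, using only names that appear in the diff (imports as in the script itself):

# Before #92: callers threaded the mode through every layer-level API.
tp = gpc.config.parallel.tensor.mode
model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
accuracy = Accuracy(tensor_parallel=tp)

# After #92: the mode comes from gpc.config internally, so the keyword
# (and the tp lookup itself) drop out of user code entirely.
model = vit_lite_depth7_patch4_32()
criterion = CrossEntropyLoss(label_smoothing=0.1)
accuracy = Accuracy()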

@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3
-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'
 NUM_EPOCHS = 300

@@ -159,14 +159,12 @@ def train_imagenet():
         os.mkdir(log_path)
     logger.log_to_file(log_path)
-    tp = gpc.config.parallel.tensor.mode
-    model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
+    model = vit_small_patch16_224(num_classes=100, init_method='jax')
     train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
     test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
-    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+    criterion = CrossEntropyLoss(label_smoothing=0.1)
     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
@@ -192,7 +190,7 @@ def train_imagenet():
         LogMetricByStepHook(),
         # LogTimingByEpochHook(timer=timer, logger=logger),
         # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+        AccuracyHook(accuracy_func=Accuracy()),
         LossHook(),
         ThroughputHook(),
         LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)

@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
 LEARNING_RATE = 3e-3
 WEIGHT_DECAY = 0.3
-TENSOR_PARALLEL_SIZE = 4
+TENSOR_PARALLEL_SIZE = 2
 TENSOR_PARALLEL_MODE = '1d'
 NUM_EPOCHS = 300

@@ -159,14 +159,12 @@ def train_imagenet():
         os.mkdir(log_path)
     logger.log_to_file(log_path)
-    tp = gpc.config.parallel.tensor.mode
-    model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
+    model = vit_small_patch16_224(num_classes=1000, init_method='jax')
     train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
     test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
-    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
+    criterion = CrossEntropyLoss(label_smoothing=0.1)
     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
@@ -192,7 +190,7 @@ def train_imagenet():
         LogMetricByStepHook(),
         # LogTimingByEpochHook(timer=timer, logger=logger),
         # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
+        AccuracyHook(accuracy_func=Accuracy()),
         LossHook(),
         ThroughputHook(),
         LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
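
The same two-line substitution repeats in every script above, which suggests the internal mechanism; the sketch below is an assumption for illustration only (the dispatch helper is hypothetical, not ColossalAI source):

from colossalai.core import global_context as gpc

def _accuracy_for_mode(mode):
    # Hypothetical dispatch: a real version would return the 1d/2d/2.5d/3d
    # accuracy kernel matching the mode; a plain top-1 stands in here.
    def top1(logits, targets):
        return (logits.argmax(dim=-1) == targets).float().mean()
    return top1

class Accuracy:
    """Hypothetical sketch of the reworked metric, not the actual source."""

    def __init__(self):
        # Assumed replacement for the old tensor_parallel kwarg: read the
        # mode from the global parallel context at construction time.
        self._impl = _accuracy_for_mode(gpc.config.parallel.tensor.mode)

    def __call__(self, logits, targets):
        return self._impl(logits, targets)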