mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-21 23:02:07 +00:00
Hotfix/Colossalai layers (#92)
* optimized 1d layer apis; reorganized nn.layer modules; fixed tests * fixed 2.5d runtime issue * reworked split batch, now called in trainer.schedule.load_batch Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,7 @@ BATCH_SIZE = 512
|
||||
LEARNING_RATE = 2e-3
|
||||
WEIGHT_DECAY = 3e-2
|
||||
|
||||
TENSOR_PARALLEL_SIZE = 4
|
||||
TENSOR_PARALLEL_SIZE = 2
|
||||
TENSOR_PARALLEL_MODE = '1d'
|
||||
|
||||
NUM_EPOCHS = 200
|
||||
|
@@ -72,13 +72,11 @@ def train_cifar():
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
tp = gpc.config.parallel.tensor.mode
|
||||
|
||||
model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
|
||||
model = vit_lite_depth7_patch4_32()
|
||||
|
||||
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
@@ -107,7 +105,7 @@ def train_cifar():
|
||||
LogMetricByStepHook(),
|
||||
# LogTimingByEpochHook(timer=timer, logger=logger),
|
||||
# LogMemoryByEpochHook(logger=logger),
|
||||
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
|
||||
AccuracyHook(accuracy_func=Accuracy()),
|
||||
LossHook(),
|
||||
ThroughputHook(),
|
||||
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
|
||||
|
@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
|
||||
LEARNING_RATE = 3e-3
|
||||
WEIGHT_DECAY = 0.3
|
||||
|
||||
TENSOR_PARALLEL_SIZE = 4
|
||||
TENSOR_PARALLEL_SIZE = 2
|
||||
TENSOR_PARALLEL_MODE = '1d'
|
||||
|
||||
NUM_EPOCHS = 300
|
||||
|
@@ -159,14 +159,12 @@ def train_imagenet():
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
tp = gpc.config.parallel.tensor.mode
|
||||
|
||||
model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
|
||||
model = vit_small_patch16_224(num_classes=100, init_method='jax')
|
||||
|
||||
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
@@ -192,7 +190,7 @@ def train_imagenet():
|
||||
LogMetricByStepHook(),
|
||||
# LogTimingByEpochHook(timer=timer, logger=logger),
|
||||
# LogMemoryByEpochHook(logger=logger),
|
||||
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
|
||||
AccuracyHook(accuracy_func=Accuracy()),
|
||||
LossHook(),
|
||||
ThroughputHook(),
|
||||
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
|
||||
|
@@ -4,7 +4,7 @@ TOTAL_BATCH_SIZE = 4096
|
||||
LEARNING_RATE = 3e-3
|
||||
WEIGHT_DECAY = 0.3
|
||||
|
||||
TENSOR_PARALLEL_SIZE = 4
|
||||
TENSOR_PARALLEL_SIZE = 2
|
||||
TENSOR_PARALLEL_MODE = '1d'
|
||||
|
||||
NUM_EPOCHS = 300
|
||||
|
@@ -159,14 +159,12 @@ def train_imagenet():
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
tp = gpc.config.parallel.tensor.mode
|
||||
|
||||
model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
|
||||
model = vit_small_patch16_224(num_classes=1000, init_method='jax')
|
||||
|
||||
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
|
||||
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
@@ -192,7 +190,7 @@ def train_imagenet():
|
||||
LogMetricByStepHook(),
|
||||
# LogTimingByEpochHook(timer=timer, logger=logger),
|
||||
# LogMemoryByEpochHook(logger=logger),
|
||||
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
|
||||
AccuracyHook(accuracy_func=Accuracy()),
|
||||
LossHook(),
|
||||
ThroughputHook(),
|
||||
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
|
||||
|
Reference in New Issue
Block a user