Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 13:30:19 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -18,11 +18,11 @@ NUM_HEADS = 4
 MLP_RATIO = 2
 NUM_CLASSES = 10
 CHECKPOINT = False
-SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1  # add 1 for cls token
+SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1  # add 1 for cls token

 # parallel setting
 TENSOR_PARALLEL_SIZE = 2
-TENSOR_PARALLEL_MODE = '1d'
+TENSOR_PARALLEL_MODE = "1d"

 parallel = dict(
     pipeline=2,
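An aside on the SEQ_LENGTH line above: it counts the patch grid plus the prepended class token. A quick standalone check of the formula, with IMG_SIZE and PATCH_SIZE as assumed values (the real ones are defined earlier in the config, outside this hunk):

# Assumed values; the config's actual IMG_SIZE/PATCH_SIZE are not in this hunk.
IMG_SIZE = 32
PATCH_SIZE = 4
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1  # 8*8 = 64 patch tokens + 1 cls token
assert SEQ_LENGTH == 65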
@@ -33,4 +33,4 @@ fp16 = dict(mode=AMP_TYPE.NAIVE)
 clip_grad_norm = 1.0

 # pipeline config
-NUM_MICRO_BATCHES = parallel['pipeline']
+NUM_MICRO_BATCHES = parallel["pipeline"]
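NUM_MICRO_BATCHES ties the schedule to the pipeline degree: each global batch is split into that many micro-batches that the pipeline stages process back-to-back. A minimal illustration of the split itself; the batch shape is an assumption:

import torch

NUM_MICRO_BATCHES = 2              # matches parallel["pipeline"] above
batch = torch.randn(8, 3, 32, 32)  # assumed global batch of 8 CIFAR-sized images
micro_batches = batch.chunk(NUM_MICRO_BATCHES, dim=0)
assert all(mb.shape[0] == 4 for mb in micro_batches)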
@@ -14,8 +14,7 @@ from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.utils import is_using_pp


-class DummyDataloader():
-
+class DummyDataloader:
     def __init__(self, length, batch_size):
         self.length = length
         self.batch_size = batch_size
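Only __init__ of DummyDataloader appears in this hunk. For orientation, a minimal sketch of how such a synthetic loader is commonly completed; the __len__/__iter__ bodies and tensor shapes below are assumptions, not the repository's code:

import torch

class DummyDataloader:
    def __init__(self, length, batch_size):
        self.length = length          # number of batches per epoch
        self.batch_size = batch_size

    def __len__(self):
        return self.length

    def __iter__(self):
        # Yield random (image, label) pairs shaped like the ViT inputs.
        for _ in range(self.length):
            data = torch.randn(self.batch_size, 3, 32, 32)    # assumed image shape
            label = torch.randint(0, 10, (self.batch_size,))  # NUM_CLASSES = 10
            yield data, label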
@@ -50,7 +49,7 @@ def main():
     logger = get_dist_logger()
     logger.info("initialized distributed environment", ranks=[0])

-    if hasattr(gpc.config, 'LOG_PATH'):
+    if hasattr(gpc.config, "LOG_PATH"):
         if gpc.get_global_rank() == 0:
             log_path = gpc.config.LOG_PATH
             if not os.path.exists(log_path):
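The hasattr gate above is the script's way of handling an optional config key: the block runs only when the user's config defines LOG_PATH, and only on global rank 0. The same pattern standalone; the hunk is cut off, so the makedirs continuation is an assumption, and SimpleNamespace stands in for gpc.config:

import os
from types import SimpleNamespace

config = SimpleNamespace(LOG_PATH="./logs")   # assumed config value
if hasattr(config, "LOG_PATH"):
    if not os.path.exists(config.LOG_PATH):
        os.makedirs(config.LOG_PATH)          # presumed continuation of the hunk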
@@ -60,15 +59,17 @@ def main():
     use_pipeline = is_using_pp()

     # create model
-    model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
-                        patch_size=gpc.config.PATCH_SIZE,
-                        hidden_size=gpc.config.HIDDEN_SIZE,
-                        depth=gpc.config.DEPTH,
-                        num_heads=gpc.config.NUM_HEADS,
-                        mlp_ratio=gpc.config.MLP_RATIO,
-                        num_classes=10,
-                        init_method='jax',
-                        checkpoint=gpc.config.CHECKPOINT)
+    model_kwargs = dict(
+        img_size=gpc.config.IMG_SIZE,
+        patch_size=gpc.config.PATCH_SIZE,
+        hidden_size=gpc.config.HIDDEN_SIZE,
+        depth=gpc.config.DEPTH,
+        num_heads=gpc.config.NUM_HEADS,
+        mlp_ratio=gpc.config.MLP_RATIO,
+        num_classes=10,
+        init_method="jax",
+        checkpoint=gpc.config.CHECKPOINT,
+    )

     if use_pipeline:
         pipelinable = PipelinableContext()
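The new layout in these hunks is characteristic of black: double quotes, spaced `**`, and a "magic trailing comma" that forces one argument per line. Assuming black is the formatter this pre-commit update runs (the exact hook configuration is not shown here), the transformation can be reproduced with its Python API:

import black

src = (
    "model_kwargs = dict(img_size=cfg.IMG_SIZE, patch_size=cfg.PATCH_SIZE, "
    "hidden_size=cfg.HIDDEN_SIZE, depth=cfg.DEPTH, num_heads=cfg.NUM_HEADS, "
    "mlp_ratio=cfg.MLP_RATIO, num_classes=10, init_method='jax', checkpoint=cfg.CHECKPOINT)\n"
)
# line_length=120 is an assumption about the project's setting.
print(black.format_str(src, mode=black.Mode(line_length=120)))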
@@ -102,16 +103,18 @@ def main():
     optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

     # create lr scheduler
-    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
-                                           total_steps=gpc.config.NUM_EPOCHS,
-                                           warmup_steps=gpc.config.WARMUP_EPOCHS)
+    lr_scheduler = CosineAnnealingWarmupLR(
+        optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS
+    )

     # initialize
-    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
-                                                                         optimizer=optimizer,
-                                                                         criterion=criterion,
-                                                                         train_dataloader=train_dataloader,
-                                                                         test_dataloader=test_dataloader)
+    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        train_dataloader=train_dataloader,
+        test_dataloader=test_dataloader,
+    )

     logger.info("Engine is built", ranks=[0])
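CosineAnnealingWarmupLR, as its name and arguments suggest, warms the learning rate up over warmup_steps and then anneals it along a cosine curve until total_steps (here both counted in epochs). A self-contained sketch of that schedule's shape; the constants and the linear-warmup detail are assumptions, not values from the config:

import math

def lr_at(step, base_lr=3e-3, total_steps=300, warmup_steps=32):
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps              # linear warmup
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))   # cosine annealing

print(lr_at(0), lr_at(32), lr_at(299))  # ramps up, peaks at base_lr, decays toward 0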
@@ -121,7 +124,7 @@ def main():
         data_iter = iter(train_dataloader)

         if gpc.get_global_rank() == 0:
-            description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
+            description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS)
             progress = tqdm(range(len(train_dataloader)), desc=description)
         else:
             progress = range(len(train_dataloader))
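The branch above gates tqdm on the global rank so that only one process in the job draws a progress bar while the rest iterate silently. The same pattern factored into a standalone helper, for reference:

from tqdm import tqdm

def make_progress(num_batches, epoch, num_epochs, is_rank_zero):
    # Only the designated rank gets a live progress bar; everyone else gets a
    # plain range so the loop body stays identical on all ranks.
    if is_rank_zero:
        return tqdm(range(num_batches), desc="Epoch {} / {}".format(epoch, num_epochs))
    return range(num_batches)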
@@ -133,5 +136,5 @@ def main():
     gpc.destroy()


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()