mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -10,8 +10,7 @@ from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import Lamb, Lars
|
||||
|
||||
|
||||
class DummyDataloader():
|
||||
|
||||
class DummyDataloader:
|
||||
def __init__(self, length, batch_size):
|
||||
self.length = length
|
||||
self.batch_size = batch_size
|
||||
@@ -39,10 +38,9 @@ class DummyDataloader():
|
||||
def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument('--optimizer',
|
||||
choices=['lars', 'lamb'],
|
||||
help="Choose your large-batch optimizer",
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
"--optimizer", choices=["lars", "lamb"], help="Choose your large-batch optimizer", required=True
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from torch
|
||||
@@ -70,16 +68,18 @@ def main():
|
||||
optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
# create lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||
total_steps=gpc.config.NUM_EPOCHS,
|
||||
warmup_steps=gpc.config.WARMUP_EPOCHS)
|
||||
lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS
|
||||
)
|
||||
|
||||
# initialize
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader)
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader,
|
||||
)
|
||||
|
||||
logger.info("Engine is built", ranks=[0])
|
||||
|
||||
@@ -89,7 +89,7 @@ def main():
|
||||
data_iter = iter(train_dataloader)
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
|
||||
description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS)
|
||||
progress = tqdm(range(len(train_dataloader)), desc=description)
|
||||
else:
|
||||
progress = range(len(train_dataloader))
|
||||
@@ -100,5 +100,5 @@ def main():
|
||||
lr_scheduler.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user