Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2026-04-25 09:12:10 +00:00)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
34 lines
876 B
Python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer


def run_trainer():
    # set up the distributed environment and build the engine and dataloaders from the config
    engine, train_dataloader, test_dataloader = colossalai.initialize()
    logger = get_global_dist_logger()

    # disable data synchronization in the training schedule for this example
    engine.schedule.data_sync = False
    logger.info("engine is built", ranks=[0])

    trainer = Trainer(engine=engine,
                      verbose=True)
    logger.info("trainer is built", ranks=[0])

    logger.info("start training", ranks=[0])
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.num_epochs,
        hooks_cfg=gpc.config.hooks,
        display_progress=True,
        test_interval=2
    )


if __name__ == '__main__':
    run_trainer()
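Note on configuration (a hedged sketch, not part of the repository file): trainer.fit reads gpc.config.num_epochs and gpc.config.hooks, so the Python config file supplied to colossalai.initialize() must define at least those two entries. The values below are illustrative placeholders only.

# config.py (hypothetical minimal sketch)
num_epochs = 10   # read by the script as gpc.config.num_epochs
hooks = []        # read as gpc.config.hooks; a list of hook config dicts consumed by the Trainer

How the config path is passed to the script and how it is launched across GPUs depend on the ColossalAI version in use.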