Files
ColossalAI/examples/run_trainer.py
Frank Lee 3defa32aee Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler

* fixed the FP16 optimizer and adapted torch amp to work with tensor parallel (#18)

* fixed compatibility bugs between torch amp and tensor parallel and made some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
2021-11-18 19:45:06 +08:00

34 lines
876 B
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer


def run_trainer():
    # build the engine and dataloaders from the user-provided config
    engine, train_dataloader, test_dataloader = colossalai.initialize()
    logger = get_global_dist_logger()

    # disable data synchronisation in the schedule
    engine.schedule.data_sync = False
    logger.info("engine is built", ranks=[0])

    # wrap the engine with the Trainer API
    trainer = Trainer(engine=engine,
                      verbose=True)
    logger.info("trainer is built", ranks=[0])

    logger.info("start training", ranks=[0])
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.num_epochs,
        hooks_cfg=gpc.config.hooks,
        display_progress=True,
        test_interval=2  # evaluate on the test set every 2 epochs
    )


if __name__ == '__main__':
    run_trainer()
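
For context, run_trainer.py leaves everything to the configuration object exposed as gpc.config: colossalai.initialize() builds the engine from it, and trainer.fit() reads num_epochs and hooks from it. Below is a minimal sketch of such a config in the registry-style format used by the examples; the hook type names, the fp16 block, and the AMP_TYPE import path are assumptions for illustration and should be checked against the example configs shipped with this commit.

# config.py -- minimal sketch of what run_trainer.py expects in gpc.config.
# The hook names, the fp16 block and the AMP_TYPE import path are assumptions;
# consult the bundled example configs for the exact spelling.
from colossalai.engine import AMP_TYPE  # assumed import location

num_epochs = 2          # consumed as gpc.config.num_epochs by trainer.fit()

# TP-compatible Torch AMP is the subject of this commit (field layout assumed)
fp16 = dict(mode=AMP_TYPE.TORCH)

# registry-style hook configs, consumed via trainer.fit(hooks_cfg=gpc.config.hooks)
hooks = [
    dict(type='LogMetricByEpochHook'),  # assumed hook name
    dict(type='LossHook'),              # assumed hook name
    dict(type='AccuracyHook'),          # assumed hook name
]

The config path is normally supplied on the command line and picked up inside colossalai.initialize(); the exact flag depends on the argument parser in this release.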