mirror of https://github.com/hpcaitech/ColossalAI.git
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapt torch amp to tensor parallelism (#18)
* fixed compatibility bugs between torch amp and tensor parallelism; minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
46 lines
1.1 KiB
Python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os.path as osp

import pytest
import torch

from colossalai import initialize
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger

NUM_BATCH = 128
BATCH_SIZE = 32
SEQ_LENGTH = 128
HIDDEN_SIZE = 512

DIR_PATH = osp.dirname(osp.realpath(__file__))
PIPE_CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')


def run_pipeline(config):
    # initialize() parses the config file and returns the engine together
    # with the train/test dataloaders
    engine, train_dataloader, test_dataloader = initialize(config)
    logger = get_global_dist_logger()
    rank = torch.distributed.get_rank()

    # run a single training step through the pipeline schedule
    engine.train()
    outputs, labels, loss = engine.step(iter(train_dataloader))

    # only the last pipeline stage holds the final loss
    if gpc.is_last_rank(ParallelMode.PIPELINE):
        logger.info('rank = {}, loss = {}'.format(rank, loss.item()))

    gpc.destroy()
    logger.info('Test engine pipeline finished')


@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_engine():
    run_pipeline(PIPE_CONFIG_PATH)


if __name__ == '__main__':
    test_engine()
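For reference, PIPE_CONFIG_PATH points at a plain Python module that initialize() consumes. The sketch below shows roughly what such a pipeline config could look like under the early ColossalAI convention of declarative dict specs; every field name and value here (the VanillaResNet type string, the dataset spec, the stage count) is an illustrative assumption, not the repository's actual pipeline_vanilla_resnet.py.

# pipeline_vanilla_resnet.py -- illustrative sketch only; field names and
# values are assumptions, not the real config shipped with the tests
BATCH_SIZE = 32
NUM_EPOCHS = 2

# process-group layout: two pipeline stages, no tensor parallelism (assumed)
parallel = dict(
    pipeline=dict(size=2),
    tensor=dict(size=1, mode=None),
)

# registry-style specs that initialize() would resolve into real objects
# (all type strings below are hypothetical)
model = dict(type='VanillaResNet', num_cls=10)
optimizer = dict(type='Adam', lr=1e-3)
loss = dict(type='CrossEntropyLoss')
train_data = dict(
    dataset=dict(type='CIFAR10Dataset', root='./data', train=True),
    dataloader=dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=2),
)

Because the test is decorated with pytest.mark.skip, running pytest directly does nothing; per the skip message it is meant to be launched through the provided test.sh, which presumably spawns one process per pipeline stage so that torch.distributed.get_rank() and the PIPELINE parallel mode are valid.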