Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
This commit is contained in:
Frank Lee
2021-11-18 19:45:06 +08:00
committed by GitHub
parent 2b05de4c64
commit 3defa32aee
80 changed files with 2194 additions and 1584 deletions

View File

@@ -6,18 +6,20 @@ import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Union
from typing import Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader
from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss,
build_lr_scheduler, build_model, build_optimizer,
build_optimizer_wrapper)
build_model, build_optimizer,
build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp
@@ -182,7 +184,7 @@ def initialize(config: Union[str, dict] = None,
backend: str = None,
train_dataloader: Optional[Union[Iterable, Callable]] = None,
test_dataloader: Optional[Union[Iterable, Callable]] = None,
):
) -> Tuple[Engine, DataLoader, DataLoader]:
'''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
:param config: config file or config file path are both acceptable
@@ -201,7 +203,7 @@ def initialize(config: Union[str, dict] = None,
:type train_dataloader: Optional[Union[Iterable, Callable]], optional
:param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
:type test_dataloader: Optional[Union[Iterable, Callable]], optional
:return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler)
:return: (engine, train_dataloader, test_dataloader, criterion)
:rtype: tuple
'''
# initialize distributed environment
@@ -337,21 +339,7 @@ def initialize(config: Union[str, dict] = None,
optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
logger.info('Optimizer is created', ranks=[0])
lr_scheduler = None
if hasattr(gpc.config, 'lr_scheduler'):
if hasattr(gpc.config, 'num_steps'):
total_steps = gpc.config.num_steps
elif hasattr(gpc.config, 'num_epochs'):
total_steps = int(gpc.config.num_epochs * len(train_dataloader))
else:
raise Exception(
'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
)
lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
total_steps, len(train_dataloader))
logger.info('Learning rate scheduler is created', ranks=[0])
# pipeline or no pipeline schedule
# build schedule and engine
if hasattr(gpc.config, 'fp16'):
amp_type = gpc.config.fp16.mode
amp_cfg = gpc.config.fp16.copy()
@@ -360,12 +348,32 @@ def initialize(config: Union[str, dict] = None,
amp_type = None
amp_cfg = None
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
assert hasattr(gpc.config,
'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training"
engine_cfg = gpc.config.get('engine', dict())
schedule_cfg = engine_cfg.pop('schedule', None)
schedule_type = None
if schedule_cfg is not None:
schedule_type = schedule_cfg.get('type', None)
if schedule_type is not None:
# run customized schedule
schedule_cfg['amp_type'] = amp_type
schedule_cfg['amp_config'] = amp_cfg
schedule = build_schedule(schedule_cfg)
elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
assert schedule_cfg is not None, \
"Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
schedule = PipelineSchedule(
amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy())
amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
else:
schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler
engine = Engine(
model=model,
optimizer=optimizer,
criterion=criterion,
step_schedule=schedule,
**gpc.config.get('engine', dict())
)
return engine, train_dataloader, test_dataloader