Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* Fix FP16 optimizer and adapt torch amp to tensor parallel (#18)
* Fix bugs in compatibility between torch amp and tensor parallel and make some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699.
* Improve consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
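For orientation, here is a minimal, hypothetical sketch of how a training script would call the updated API after this change. It assumes `initialize` is exposed as `colossalai.initialize` and that a config file path is passed; the path and anything done with the returned engine afterwards are illustrative, not part of this commit.

import colossalai

# After this commit, initialize() returns (engine, train_dataloader, test_dataloader)
# rather than the previous 7-tuple of raw components.
engine, train_dataloader, test_dataloader = colossalai.initialize(
    config='./config.py'  # hypothetical config path; a dict is also accepted per the signature
)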
@@ -6,18 +6,20 @@ import pprint
 import random
 from pathlib import Path
 from typing import Callable, Iterable, Optional, Union
+from typing import Tuple

 import numpy as np
 import torch
 from torch.utils.data import DataLoader

 from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
+from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger, init_global_dist_logger
 from colossalai.nn import DataParallelSampler
 from colossalai.nn.model.base_model import BaseModel
 from .builder import (ModelInitializer, build_dataset, build_loss,
-                      build_lr_scheduler, build_model, build_optimizer,
-                      build_optimizer_wrapper)
+                      build_model, build_optimizer,
+                      build_optimizer_wrapper, build_schedule)
 from .context import Config, ParallelMode
 from .core import global_context as gpc
 from .utils import get_current_device, sync_model_param_in_dp
@@ -182,7 +184,7 @@ def initialize(config: Union[str, dict] = None,
                backend: str = None,
                train_dataloader: Optional[Union[Iterable, Callable]] = None,
                test_dataloader: Optional[Union[Iterable, Callable]] = None,
-               ):
+               ) -> Tuple[Engine, DataLoader, DataLoader]:
     '''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).

     :param config: config file or config file path are both acceptable
@@ -201,7 +203,7 @@ def initialize(config: Union[str, dict] = None,
     :type train_dataloader: Optional[Union[Iterable, Callable]], optional
     :param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
     :type test_dataloader: Optional[Union[Iterable, Callable]], optional
-    :return: (model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler)
+    :return: (engine, train_dataloader, test_dataloader, criterion)
     :rtype: tuple
     '''
     # initialize distributed environment
@@ -337,21 +339,7 @@ def initialize(config: Union[str, dict] = None,
         optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
     logger.info('Optimizer is created', ranks=[0])

-    lr_scheduler = None
-    if hasattr(gpc.config, 'lr_scheduler'):
-        if hasattr(gpc.config, 'num_steps'):
-            total_steps = gpc.config.num_steps
-        elif hasattr(gpc.config, 'num_epochs'):
-            total_steps = int(gpc.config.num_epochs * len(train_dataloader))
-        else:
-            raise Exception(
-                'Please specify training stopping criterion num_steps or num_epochs in your configuration.'
-            )
-        lr_scheduler = build_lr_scheduler(gpc.config.lr_scheduler, optimizer,
-                                          total_steps, len(train_dataloader))
-        logger.info('Learning rate scheduler is created', ranks=[0])
-
-    # pipeline or no pipeline schedule
+    # build schedule and engine
     if hasattr(gpc.config, 'fp16'):
         amp_type = gpc.config.fp16.mode
         amp_cfg = gpc.config.fp16.copy()
@@ -360,12 +348,32 @@ def initialize(config: Union[str, dict] = None,
         amp_type = None
         amp_cfg = None

-    if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-        assert hasattr(gpc.config,
-                       'schedule'), "Config 'schedule' not found in your configuration file for pipeline parallel training"
+    engine_cfg = gpc.config.get('engine', dict())
+    schedule_cfg = engine_cfg.pop('schedule', None)
+
+    schedule_type = None
+    if schedule_cfg is not None:
+        schedule_type = schedule_cfg.get('type', None)
+
+    if schedule_type is not None:
+        # run customized schedule
+        schedule_cfg['amp_type'] = amp_type
+        schedule_cfg['amp_config'] = amp_cfg
+        schedule = build_schedule(schedule_cfg)
+    elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+        assert schedule_cfg is not None, \
+            "Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
         schedule = PipelineSchedule(
-            amp_type=amp_type, amp_config=amp_cfg, **gpc.config.schedule.copy())
+            amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
     else:
         schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)

-    return model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler
+    engine = Engine(
+        model=model,
+        optimizer=optimizer,
+        criterion=criterion,
+        step_schedule=schedule,
+        **gpc.config.get('engine', dict())
+    )
+
+    return engine, train_dataloader, test_dataloader
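For reference, a hypothetical config fragment that the new schedule/engine wiring above would read; the AMP mode member and the schedule 'type' string are assumptions for illustration, not values taken from this diff.

# config.py (sketch; keys mirror the lookups in the code above)
from colossalai.engine import AMP_TYPE

fp16 = dict(mode=AMP_TYPE.TORCH)  # consumed via gpc.config.fp16.mode and gpc.config.fp16.copy()

engine = dict(
    # 'schedule' is popped off the engine config; a 'type' key routes it through
    # build_schedule() instead of the built-in Pipeline/NoPipeline schedules
    schedule=dict(type='PipelineSchedule'),  # assumed registry name, illustrative only
)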