Develop/experiments (#59)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3d layer to fix slow computation; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

update api for better usability

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Author: Frank Lee
Date: 2021-12-09 15:08:29 +08:00 (committed by GitHub)
Commit: da01c234e1 (parent eb2f8b1f6b)
229 changed files with 6532 additions and 8741 deletions
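The options exercised by this rework (mixed precision, gradient accumulation, gradient clipping, tensor/pipeline parallel sizes) are read from the user config that is loaded into gpc.config. For orientation, a minimal config sketch follows; the field names match the options referenced in the diff below, while the concrete values and the parallel layout are illustrative assumptions, not part of the commit:

# config.py -- illustrative sketch only; values and parallel layout are assumptions
from colossalai.amp import AMP_TYPE

BATCH_SIZE = 128
NUM_EPOCHS = 200

fp16 = dict(mode=AMP_TYPE.TORCH)      # TP-compatible Torch AMP path
gradient_accumulation = 4             # accumulate 4 micro-batches per optimizer step
clip_grad_norm = 1.0                  # forwarded to Engine(clip_grad_norm=...)
cudnn_benchmark = True
cudnn_deterministic = False

parallel = dict(
    pipeline=1,
    tensor=dict(size=2, mode='1d'),   # 1D tensor parallelism from the commit message
)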


@@ -3,377 +3,326 @@
 import argparse
 import pprint
-import random
-from pathlib import Path
-from typing import Callable, Iterable, Optional, Union
-from typing import Tuple
+import os
+from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 import numpy as np
 import torch
-from torch.utils.data import DataLoader
+import torch.nn as nn
-from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
+from pathlib import Path
+from typing import Iterable, Union, Optional, Tuple, List, Dict
+from colossalai.amp import convert_to_amp, AMP_TYPE
+from colossalai.context import Config, ParallelMode, ConfigException
+from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
-from colossalai.logging import get_global_dist_logger, init_global_dist_logger
-from colossalai.nn import DataParallelSampler
-from colossalai.nn.model.base_model import BaseModel
-from .builder import (ModelInitializer, build_dataset, build_loss,
-                      build_model, build_optimizer,
-                      build_optimizer_wrapper, build_schedule)
-from .context import Config, ParallelMode
-from .core import global_context as gpc
-from .utils import get_current_device, sync_model_param_in_dp
+from colossalai.logging import get_dist_logger
+from colossalai.utils import (accumulate_gradient, get_current_device,
+                              sync_model_param_in_dp, is_using_ddp, is_using_pp)
+from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
+from colossalai.builder.builder import build_gradient_handler
+from torch.optim.optimizer import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+from torch.nn.modules.loss import _Loss
+from torch.nn.parallel import DistributedDataParallel as DDP

-def parse_args():
+def get_default_parser():
     '''Reads user command line and uses an argument parser to parse the input arguments.
     Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
-    :return: call the parse arguments function
+    :return: returns the parser with the default arguments, the user may add customized arguments into this parser
     :rtype: Namespace
     '''
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, help='path to the config file')
     parser.add_argument('--host',
                         type=str,
                         default=None,
                         help='the master address for distributed training')
     parser.add_argument('--port',
-                        type=str,
                         default=None,
+                        type=int,
                         help='the master port for distributed training')
-    parser.add_argument('--world_size', type=int, help='world size for ')
+    parser.add_argument('--world_size', type=int, help='world size for distributed training')
+    parser.add_argument('--rank', type=int, help='rank for the default process group')
     parser.add_argument('--local_rank',
                         type=int,
-                        help='rank for the default process group')
+                        help='local rank on the node')
     parser.add_argument('--backend',
                         type=str,
                         default='nccl',
-                        help='backend for torch.distributed')
-    return parser.parse_args()
+                        help='backend for distributed communication')
+    return parser
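The old helper parsed the arguments itself; the new one returns the parser so callers can extend it. A usage sketch (the extra --epochs flag is a hypothetical user addition):

# sketch: extend the default parser before parsing (hypothetical extra flag)
from colossalai.initialize import get_default_parser

parser = get_default_parser()
parser.add_argument('--epochs', type=int, default=10)
args = parser.parse_args()
# args.host, args.port, args.rank, args.local_rank, args.world_size and
# args.backend can then be forwarded to launch() below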

-def init_dist(config: Union[str, dict] = None,
-              local_rank: int = None,
-              world_size: int = None,
-              host: str = None,
-              port: str = None,
-              backend: str = None):
+def launch(config: Union[str, Path, Config, Dict],
+           rank: int,
+           world_size: int,
+           host: str,
+           port: int,
+           backend: str = 'nccl',
+           local_rank: int = None,
+           seed: int = 1024,
+           verbose: bool = True):
     '''This function first parses the configuration arguments, using :func:parse_args() in case one of the input arguments are not given.
-    Then initialize and set distributed environment by calling global_context's functions.
+    Then initialize and set distributed environment by calling global_context's functions.
     :param config: config file or config file path are both acceptable
-    :type config: Union[str, dict], optional
-    :param local_rank: rank for the default process group, defaults to None
+    :type config: Union[str, dict, Config]
+    :param rank: rank for the default process group
+    :type rank: int
+    :param world_size: world size of the default process group
+    :type world_size: int
+    :param host: the master address for distributed training
+    :type host: str
+    :param port: the master port for distributed training
+    :type port: str
+    :param backend: backend for torch.distributed
+    :type backend: str
+    :param local_rank: rank for the process on the node and is used to set the default CUDA device,
+        defaults to None. If local_rank = None, the default device ordinal will be calculated automatically
+    :type local_rank: int, optional
-    :param world_size: world size of GPUs, defaults to None
-    :type world_size: int, optional
-    :param host: the master address for distributed training, defaults to None
-    :type host: str, optional
-    :param port: the master port for distributed training, defaults to None
-    :type port: str, optional
-    :param backend: backend for torch.distributed, defaults to None
-    :type backend: str, optional
     :raises Exception: raise exception when config type is wrong
     '''
-    args = [config, local_rank, world_size, host, port, backend]
-    arg_given = [arg is not None for arg in args]
-    if not all(arg_given):
-        args = parse_args()
-        if config is None:
-            config = args.config
-        if local_rank is None:
-            local_rank = args.local_rank
-        if world_size is None:
-            world_size = args.world_size
-        if host is None:
-            host = args.host
-        if port is None:
-            port = args.port
-        if backend is None:
-            backend = args.backend
-    args = Config(
-        dict(config=config,
-             host=host,
-             port=port,
-             world_size=world_size,
-             local_rank=local_rank,
-             backend=backend))
-    # set distributed settings
-    dist_args = Config(
-        dict(local_rank=args.local_rank,
-             world_size=args.world_size,
-             backend=args.backend))
-    gpc.set_dist_args(dist_args)
+    gpc.verbose = verbose
     # set config
-    if isinstance(args.config, dict):
-        cfg = args.config
-    elif isinstance(args.config, (str, Path)):
-        cfg = Config.from_file(args.config)
-    else:
-        raise Exception('Config type error: {}'.format(type(args.config)))
-    gpc.load_config(cfg)
+    assert isinstance(config, (Config, str, Path, dict)), \
+        f'expected argument config to be Config, str or Path, but got {type(config)}'
+    if not isinstance(config, Config) and isinstance(config, dict):
+        config = Config(config)
+    if isinstance(config, (str, Path)):
+        config = Config.from_file(config)
+    gpc.load_config(config)
-    # init dist groups
-    gpc.init_global_dist(args.host, args.port)
+    # init default process group
+    gpc.init_global_dist(rank, world_size, backend, host, port)
     # init process groups for different parallel modes from config
     gpc.init_parallel_groups()
-    # init dist logger
-    init_global_dist_logger()
     # set cuda device
     if torch.cuda.is_available():
-        gpc.set_device()
+        # if local rank is not given, calculate automatically
+        gpc.set_device(local_rank)
+    gpc.set_seed(seed)
+    if verbose:
+        logger = get_dist_logger()
+        logger.info(f'Distributed environment is initialized, '
+                    f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
+                    f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
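A minimal sketch of calling the new launch() directly; the config path, address and port are placeholders, and the rank/world size are assumed to come from whatever process launcher starts the script:

# sketch: one process per GPU, placeholder values
import os
from colossalai.initialize import launch

rank = int(os.environ.get('RANK', '0'))
world_size = int(os.environ.get('WORLD_SIZE', '1'))

launch(config='./config.py',      # placeholder config path
       rank=rank,
       world_size=world_size,
       host='127.0.0.1',
       port=29500,
       backend='nccl',
       seed=1024)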

-def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
-    '''Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
-    .. note: when pipeline parallel is enabled, shuffle cannot be True
-        as it will result in mismatch between input data on the 1st
-        stage and label on the last stage
-    :param dataset: a :class:utils.data.dataset dataset
-    :param seed: random worker seed, defaults to 1024
-    :type seed: int, optional
-    :param add_sampler_if_possible: [description], defaults to False
-    :type add_sampler_if_possible: bool, optional
-    :return: a :class:utils.data.dataset dataloader
-    :rtype: torch.utils.data.dataset
-    '''
-    _kwargs = kwargs.copy()
-    if 'shuffle' in _kwargs:
-        shuffle = _kwargs.pop('shuffle')
-    else:
-        shuffle = False
-    if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
-        sampler = DataParallelSampler(dataset, shuffle=shuffle)
-    else:
-        sampler = None
-    # Deterministic dataloader
-    def seed_worker(worker_id):
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-    if sampler is None:
-        return DataLoader(dataset,
-                          worker_init_fn=seed_worker,
-                          shuffle=shuffle,
-                          **_kwargs)
-    else:
-        return DataLoader(dataset,
-                          sampler=sampler,
-                          worker_init_fn=seed_worker,
-                          **_kwargs)
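get_dataloader is deleted from this module by this commit. For reference, a sketch of how the removed helper was called, with CIFAR10 as an illustrative dataset (the dataset choice and loader settings are assumptions):

# sketch of the pre-change usage of the removed helper (illustrative dataset)
from torchvision import datasets, transforms

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=transforms.ToTensor())
train_dataloader = get_dataloader(train_dataset,
                                  seed=1024,
                                  add_sampler_if_possible=True,  # DataParallelSampler when dp > 1
                                  batch_size=128,
                                  shuffle=True,
                                  num_workers=4,
                                  pin_memory=True)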

+def launch_from_slurm(config: Union[str, Path, Config, Dict],
+                      host: str,
+                      port: int,
+                      backend: str = 'nccl',
+                      seed: int = 1024,
+                      verbose: bool = True):
+    rank = int(os.environ['SLURM_PROCID'])
+    world_size = int(os.environ['SLURM_NPROCS'])
+    launch(config=config,
+           rank=rank,
+           world_size=world_size,
+           host=host,
+           port=port,
+           backend=backend,
+           seed=seed,
+           verbose=verbose)
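Under SLURM, each task reads SLURM_PROCID and SLURM_NPROCS through the helper above; the script only supplies the config and the rendezvous address. A sketch with placeholder host, port and srun line:

# train.py (sketch) -- typically started with something like:
#   srun -N 2 --ntasks-per-node=4 python train.py
from colossalai.initialize import launch_from_slurm

launch_from_slurm(config='./config.py',
                  host='node001',    # address of the rank-0 node
                  port=29500)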

-def initialize(config: Union[str, dict] = None,
-               local_rank: int = None,
-               world_size: int = None,
-               host: str = None,
-               port: str = None,
-               backend: str = None,
-               train_dataloader: Optional[Union[Iterable, Callable]] = None,
-               test_dataloader: Optional[Union[Iterable, Callable]] = None,
+def launch_from_openmpi(config: Union[str, Path, Config, Dict],
+                        host: str,
+                        port: int,
+                        backend: str = 'nccl',
+                        seed: int = 1024,
+                        verbose: bool = True):
+    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+    launch(config=config,
+           local_rank=local_rank,
+           rank=rank,
+           world_size=world_size,
+           host=host,
+           port=port,
+           backend=backend,
+           seed=seed,
+           verbose=verbose)
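The OpenMPI variant works the same way but reads the OMPI_COMM_WORLD_* variables, so it pairs with an mpirun launch. A sketch with placeholder values:

# sketch -- typically started with something like:
#   mpirun -np 8 --hostfile hosts python train.py
from colossalai.initialize import launch_from_openmpi

launch_from_openmpi(config='./config.py',
                    host='node001',   # placeholder master address
                    port=29500)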

+def launch_from_torch(config: Union[str, Path, Config, Dict],
+                      host: str,
+                      port: int,
+                      backend: str = 'nccl',
+                      seed: int = 1024,
+                      verbose: bool = True):
+    rank = int(os.environ['RANK'])
+    local_rank = int(os.environ['LOCAL_RANK'])
+    world_size = int(os.environ['WORLD_SIZE'])
+    launch(config=config,
+           local_rank=local_rank,
+           rank=rank,
+           world_size=world_size,
+           host=host,
+           port=port,
+           backend=backend,
+           seed=seed,
+           verbose=verbose)
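launch_from_torch expects RANK, LOCAL_RANK and WORLD_SIZE in the environment, i.e. the PyTorch elastic launcher (torchrun, or torch.distributed.launch with --use_env). A sketch with placeholder values:

# sketch -- typically started with:
#   torchrun --nproc_per_node 4 train.py
from colossalai.initialize import launch_from_torch

launch_from_torch(config='./config.py',
                  host='127.0.0.1',   # placeholder master address
                  port=29500)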

+def initialize(model: Union[nn.Module, List[nn.Module]],
+               optimizer: Union[Optimizer, List[Optimizer]],
+               criterion: Union[_Loss, List[_Loss]],
+               train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
+               test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
+               lr_scheduler: _LRScheduler = None,
+               verbose: bool = True
                ) -> Tuple[Engine, DataLoader, DataLoader]:
-    '''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
+    ''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
-    :param config: config file or config file path are both acceptable
-    :type config: Union[str, dict], optional
-    :param local_rank: rank for the default process group, defaults to None
-    :type local_rank: int, optional
-    :param world_size: world size of GPUs, defaults to None
-    :type world_size: int, optional
-    :param host: the master address for distributed training, defaults to None
-    :type host: str, optional
-    :param port: the master port for distributed training, defaults to None
-    :type port: str, optional
-    :param backend: backend for torch.distributed, defaults to None
-    :type backend: str, optional
-    :param train_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
-    :type train_dataloader: Optional[Union[Iterable, Callable]], optional
-    :param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
-    :type test_dataloader: Optional[Union[Iterable, Callable]], optional
-    :return: (engine, train_dataloader, test_dataloader, criterion)
+    :param model: your model instance
+    :type model: a single or a list of ``torch.nn.Module`` objects
+    :param optimizer: your optimizer instance
+    :type optimizer: a single or a list of ``torch.optim.optimizer.Optimizer`` objects
+    :param criterion: your criterion instance
+    :type criterion: a single or a list of ``torch.nn.modules.loss._Loss`` objects
+    :param train_dataloader: dataloaders for training data
+    :type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
+    :param train_dataloader: dataloaders for testing data
+    :type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
+    :return: (engine, criterion, train_dataloader, test_dataloader)
     :rtype: tuple
     '''
-    # initialize distributed environment
-    init_dist(config=config,
-              local_rank=local_rank,
-              world_size=world_size,
-              host=host,
-              port=port,
-              backend=backend)
+    # get logger
+    logger = get_dist_logger()
+    gpc.verbose = verbose
-    # init logger
-    logger = get_global_dist_logger()
-    logger.info(f'Distributed environment is initialized, '
-                f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
-                f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
+    # get config from gpc
+    config = gpc.config
     # print config
-    logger.info(f"\n========== Your Config ========\n"
-                f"{pprint.pformat(gpc.config)}\n"
-                f"================================", ranks=[0])
+    if verbose:
+        logger.info(f"\n========== Your Config ========\n"
+                    f"{pprint.pformat(gpc.config)}\n"
+                    f"================================\n", ranks=[0])
     # cudnn
-    cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
-    cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
+    cudnn_benchmark = config.get('cudnn_benchmark', True)
+    cudnn_deterministic = config.get('cudnn_deterministic', False)
     torch.backends.cudnn.benchmark = cudnn_benchmark
     torch.backends.cudnn.deterministic = cudnn_deterministic
-    logger.info(
-        f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
+    if verbose:
+        logger.info(
+            f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
-    # set seed, cuda seed is only set when cuda is avail
-    gpc.set_seed()
-    # return_items = list()
-    # check fp16 and zero
-    should_convert_model_to_half = False
-    should_wrap_fp16_optimizer = False
-    should_wrap_zero_optimizer_level_2_3 = False
-    if hasattr(gpc.config, 'fp16'):
-        fp16_mode = gpc.config.fp16.mode
-        if fp16_mode == AMP_TYPE.PARALLEL:
-            should_convert_model_to_half = True
-            should_wrap_fp16_optimizer = True
-    if hasattr(gpc.config, 'zero'):
-        should_wrap_zero_optimizer_level_2_3 = True
-        zero_type = gpc.config.zero.type
-        if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
-            should_convert_model_to_half = True
-            assert not should_wrap_fp16_optimizer, \
-                'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'
-    # build model
-    logger.info('Building model ...', ranks=[0])
-    assert hasattr(
-        gpc.config, 'model'), "Build error: configuration 'model' is missing"
-    if gpc.pipeline_parallel_size > 1:
-        model = ModelInitializer(gpc.config.model, 1, verbose=True)
-        model = model.model_initialize()
-    else:
-        model = build_model(gpc.config.model)
-        if isinstance(model, BaseModel):
-            model.build_from_cfg()
-    model = model.to(get_current_device())
+    # first sync model across dp ranks
+    model.to(get_current_device())
     sync_model_param_in_dp(model)
-    logger.info('Model is created', ranks=[0])
-    if should_convert_model_to_half:
-        model = model.half()
-        logger.info("Model is cast to fp16", ranks=[0])
+    # check amp and zero
+    fp16_cfg = gpc.config.get('fp16', None)
+    zero_cfg = gpc.config.get('zero', None)
-    # training data
-    if callable(train_dataloader):
-        logger.info(
-            f'Build train data loader from {train_dataloader}', ranks=[0])
-        train_dataloader = train_dataloader()
-    if train_dataloader is None and hasattr(gpc.config, 'train_data'):
-        logger.info('Preparing data ...', ranks=[0])
-        # assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing."
-        train_dataset = build_dataset(gpc.config.train_data.dataset)
-        logger.info('Train dataset is ready.', ranks=[0])
+    if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
+        raise ConfigException(
+            "It is not allowed to set fp16 and zero configuration in your config file at the same time")
-        train_dataloader = get_dataloader(train_dataset,
-                                          gpc.config.get('seed', 1024),
-                                          True,
-                                          **gpc.config.train_data.dataloader,
-                                          )
-        logger.info(
-            f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])
+    # initialize amp
+    amp_mode = None
+    if fp16_cfg is not None and fp16_cfg.mode is not None:
+        cfg_ = fp16_cfg.copy()
+        amp_mode = cfg_.pop('mode')
+        model, optimizer, criterion = convert_to_amp(model=model,
+                                                     optimizer=optimizer,
+                                                     criterion=criterion,
+                                                     mode=amp_mode,
+                                                     amp_config=cfg_)
-    if callable(test_dataloader):
-        logger.info(
-            f'Build test data loader from {test_dataloader}', ranks=[0])
-        test_dataloader = test_dataloader()
-    # testing data, allowed to be None
-    if test_dataloader is None and hasattr(gpc.config, 'test_data'):
-        test_dataset = build_dataset(gpc.config.test_data.dataset)
-        test_dataloader = get_dataloader(
-            test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
-        logger.info(
-            f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])
+    if zero_cfg is not None:
+        cfg_ = zero_cfg.copy()
+        level = cfg_.pop('level')
+        model, optimizer = convert_to_zero(model=model,
+                                           optimizer=optimizer,
+                                           level=level,
+                                           zero_config=cfg_
+                                           )
-    # build loss function
-    assert hasattr(gpc.config, 'loss'), \
-        'Build error: configuration \'loss\' is missing.'
-    criterion = build_loss(gpc.config.loss)
-    logger.info('Loss function is created', ranks=[0])
-    # build optimizer
-    assert hasattr(gpc.config, 'optimizer'), \
-        "Build error: configuration 'optimizer' is missing."
-    optim_type = gpc.config.optimizer.type
-    is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
-    if is_pytorch_native_zero_level_1:
-        original_cfg_copy = gpc.config.optimizer.copy()
-        original_cfg_copy.pop('type')
-        cfg = dict(type=optim_type, process_group=gpc.get_group(
-            ParallelMode.DATA), **original_cfg_copy)
-        optimizer = build_optimizer(cfg, model)
+    # gradient handler
+    gradient_handler_cfg = gpc.config.get('gradient_handler', None)
+    if gradient_handler_cfg is None:
+        # if gradient handler is not specified in the configuration file,
+        # check in the following order
+        # 1. if optimizer is ZERO, then use zero grad handler
+        # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
+        # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
+        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
+                                  ZeroRedundancyOptimizer_Level_3)):
+            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
+            if verbose:
+                logger.info(
+                    "Training with zero is detected, ZeROGradientHandler is automatically "
+                    "added even though not specified in the configuration",
+                    ranks=[0])
+        elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
+            model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
+            if verbose:
+                logger.info(
+                    'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
+        elif is_using_ddp():
+            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
+            if verbose:
+                logger.info(
+                    "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
+                    "added even though not specified in the configuration",
+                    ranks=[0])
     else:
-        optimizer = build_optimizer(gpc.config.optimizer, model)
+        if not isinstance(gradient_handler_cfg, list):
+            raise ConfigException(
+                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
-    if should_wrap_zero_optimizer_level_2_3:
-        optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)
-    if should_wrap_fp16_optimizer:
-        # replace the field mode with type
-        fp16_cfg = gpc.config.fp16.copy()
-        amp_type = fp16_cfg.pop('mode')
-        assert amp_type == AMP_TYPE.PARALLEL, 'FP Optimizer should only be used for AMP_TYPE.PARALLEL'
-        fp16_cfg['type'] = 'FP16Optimizer'
-        optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
-    logger.info('Optimizer is created', ranks=[0])
-    # build schedule and engine
-    if hasattr(gpc.config, 'fp16'):
-        amp_type = gpc.config.fp16.mode
-        amp_cfg = gpc.config.fp16.copy()
-        amp_cfg.pop('mode')
+    if gradient_handler_cfg is None:
+        gradient_handlers = None
+        if verbose and not isinstance(model, DDP):
+            logger.warning(
+                "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
+                "to all-reduce the gradients after a training step.",
+                ranks=[0])
     else:
-        amp_type = None
-        amp_cfg = None
+        gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
-    engine_cfg = gpc.config.get('engine', dict())
-    schedule_cfg = engine_cfg.pop('schedule', None)
+    # check if optimizer is ColossalaiOptimizer
+    if not isinstance(optimizer, (ColossalaiOptimizer, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
+        optimizer = ColossalaiOptimizer(optim=optimizer)
-    schedule_type = None
-    if schedule_cfg is not None:
-        schedule_type = schedule_cfg.get('type', None)
+    # gradient accumulation
+    grad_accum_size = gpc.config.get('gradient_accumulation', None)
+    if grad_accum_size is not None:
+        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
+                                                                                           optimizer=optimizer,
+                                                                                           dataloader=train_dataloader,
+                                                                                           accumulate_size=grad_accum_size,
+                                                                                           gradient_handlers=gradient_handlers,
+                                                                                           lr_scheduler=lr_scheduler)
-    if schedule_type is not None:
-        # run customized schedule
-        schedule_cfg['amp_type'] = amp_type
-        schedule_cfg['amp_config'] = amp_cfg
-        schedule = build_schedule(schedule_cfg)
-    elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-        assert schedule_cfg is not None, \
-            "Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
-        schedule = PipelineSchedule(
-            amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
-    else:
-        schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
+    # clip grad norm
+    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
+    if clip_grad_norm > 0:
+        if zero_cfg is not None:
+            raise ConfigException(
+                "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
+        elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE:
+            raise ConfigException(
+                "clip_grad_norm should be specified with AMP_TYPE.NAIVE, you should specify clip_grad in fp16 configuration")
     engine = Engine(
         model=model,
         optimizer=optimizer,
         criterion=criterion,
-        step_schedule=schedule,
-        **gpc.config.get('engine', dict())
+        gradient_handlers=gradient_handlers,
+        clip_grad_norm=clip_grad_norm
     )
-    return engine, train_dataloader, test_dataloader
+    return engine, train_dataloader, test_dataloader, lr_scheduler
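
To see the reworked flow end to end, a usage sketch of the new initialize(): the tiny model, optimizer and dataset are illustrative stand-ins, and it assumes one of the launch helpers above has already set up the distributed environment and loaded the config.

# sketch: wiring user-built components through the reworked initialize()
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from colossalai.initialize import initialize

model = nn.Linear(32, 10)                                # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
train_loader = DataLoader(TensorDataset(torch.randn(256, 32),
                                        torch.randint(0, 10, (256,))),
                          batch_size=32, shuffle=True)

engine, train_loader, _, lr_scheduler = initialize(model=model,
                                                   optimizer=optimizer,
                                                   criterion=criterion,
                                                   train_dataloader=train_loader)

The returned engine now owns the (ColossalaiOptimizer-wrapped) optimizer, the gradient handlers and the clipping threshold, so the training loop drives the engine rather than the raw optimizer, and lr_scheduler comes back possibly rewrapped when gradient accumulation is configured.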