Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
  This reverts commit 2e0b0b7699.
* improved consistency between trainer, engine and schedule (#23)
  Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
  Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
  This reverts commit 2e0b0b7699.
* improved consistency between trainer, engine and schedule (#23)
  Co-authored-by: 1SAA <c2h214748@gmail.com>
  Co-authored-by: 1SAA <c2h214748@gmail.com>
  Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
  Co-authored-by: Frank Lee <somerlee.9@gmail.com>
  Co-authored-by: 1SAA <c2h214748@gmail.com>
  Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
  This reverts commit 2e0b0b7699.
* improved consistency between trainer, engine and schedule (#23)
  Co-authored-by: 1SAA <c2h214748@gmail.com>
  Co-authored-by: 1SAA <c2h214748@gmail.com>
  Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
  Co-authored-by: Frank Lee <somerlee.9@gmail.com>
  Co-authored-by: 1SAA <c2h214748@gmail.com>
  Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
  Co-authored-by: 1SAA <c2h214748@gmail.com>
  Co-authored-by: ver217 <lhx0217@gmail.com>
  Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
  Co-authored-by: binmakeswell <binmakeswell@gmail.com>
  Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
  Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
@@ -3,377 +3,326 @@
import argparse
import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Union
from typing import Tuple

import os
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from pathlib import Path
from typing import Iterable, Union, Optional, Tuple, List, Dict

from colossalai.amp import convert_to_amp, AMP_TYPE
from colossalai.context import Config, ParallelMode, ConfigException
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss,
                      build_model, build_optimizer,
                      build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp
from colossalai.logging import get_dist_logger
from colossalai.utils import (accumulate_gradient, get_current_device,
                              sync_model_param_in_dp, is_using_ddp, is_using_pp)
from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
from colossalai.builder.builder import build_gradient_handler
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP

def parse_args():
def get_default_parser():
    '''Reads user command line and uses an argument parser to parse the input arguments.
    Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.

    :return: call the parse arguments function
    :return: returns the parser with the default arguments, the user may add customized arguments into this parser
    :rtype: Namespace
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the config file')
    parser.add_argument('--host',
                        type=str,
                        default=None,
                        help='the master address for distributed training')
    parser.add_argument('--port',
                        type=str,
                        default=None,
                        type=int,
                        help='the master port for distributed training')
    parser.add_argument('--world_size', type=int, help='world size for ')
    parser.add_argument('--world_size', type=int, help='world size for distributed training')
    parser.add_argument('--rank', type=int, help='rank for the default process group')
    parser.add_argument('--local_rank',
                        type=int,
                        help='rank for the default process group')
                        help='local rank on the node')
    parser.add_argument('--backend',
                        type=str,
                        default='nccl',
                        help='backend for torch.distributed')
    return parser.parse_args()
                        help='backend for distributed communication')
    return parser

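The new get_default_parser no longer calls parse_args() itself, so a training script can extend the parser before parsing. A minimal sketch, assuming the helper is re-exported from the top-level colossalai package; the extra --epochs flag is a hypothetical addition, not part of this diff:

import colossalai

# parser comes pre-populated with --config, --host, --port, --world_size,
# --rank, --local_rank and --backend
parser = colossalai.get_default_parser()
# hypothetical user-defined flag added on top of the defaults
parser.add_argument('--epochs', type=int, default=10, help='number of training epochs')
args = parser.parse_args()
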
def init_dist(config: Union[str, dict] = None,
              local_rank: int = None,
              world_size: int = None,
              host: str = None,
              port: str = None,
              backend: str = None):
def launch(config: Union[str, Path, Config, Dict],
           rank: int,
           world_size: int,
           host: str,
           port: int,
           backend: str = 'nccl',
           local_rank: int = None,
           seed: int = 1024,
           verbose: bool = True):
    '''This function first parses the configuration arguments, using :func:parse_args() in case one of the input arguments are not given.
    Then initialize and set distributed environment by calling global_context's functions.
    Then initialize and set distributed environment by calling global_context's functions.

    :param config: config file or config file path are both acceptable
    :type config: Union[str, dict], optional
    :param local_rank: rank for the default process group, defaults to None
    :type config: Union[str, dict, Config]
    :param rank: rank for the default process group
    :type rank: int
    :param world_size: world size of the default process group
    :type world_size: int
    :param host: the master address for distributed training
    :type host: str
    :param port: the master port for distributed training
    :type port: str
    :param backend: backend for torch.distributed
    :type backend: str
    :param local_rank: rank for the process on the node and is used to set the default CUDA device,
        defaults to None. If local_rank = None, the default device ordinal will be calculated automatically
    :type local_rank: int, optional
    :param world_size: world size of GPUs, defaults to None
    :type world_size: int, optional
    :param host: the master address for distributed training, defaults to None
    :type host: str, optional
    :param port: the master port for distributed training, defaults to None
    :type port: str, optional
    :param backend: backend for torch.distributed, defaults to None
    :type backend: str, optional
    :raises Exception: raise exception when config type is wrong
    '''
    args = [config, local_rank, world_size, host, port, backend]
    arg_given = [arg is not None for arg in args]

    if not all(arg_given):
        args = parse_args()

    if config is None:
        config = args.config
    if local_rank is None:
        local_rank = args.local_rank
    if world_size is None:
        world_size = args.world_size
    if host is None:
        host = args.host
    if port is None:
        port = args.port
    if backend is None:
        backend = args.backend
    args = Config(
        dict(config=config,
             host=host,
             port=port,
             world_size=world_size,
             local_rank=local_rank,
             backend=backend))

    # set distributed settings
    dist_args = Config(
        dict(local_rank=args.local_rank,
             world_size=args.world_size,
             backend=args.backend))

    gpc.set_dist_args(dist_args)
    gpc.verbose = verbose

    # set config
    if isinstance(args.config, dict):
        cfg = args.config
    elif isinstance(args.config, (str, Path)):
        cfg = Config.from_file(args.config)
    else:
        raise Exception('Config type error: {}'.format(type(args.config)))
    gpc.load_config(cfg)
    assert isinstance(config, (Config, str, Path, dict)), \
        f'expected argument config to be Config, str or Path, but got {type(config)}'
    if not isinstance(config, Config) and isinstance(config, dict):
        config = Config(config)
    if isinstance(config, (str, Path)):
        config = Config.from_file(config)
    gpc.load_config(config)

    # init dist groups
    gpc.init_global_dist(args.host, args.port)
    # init default process group
    gpc.init_global_dist(rank, world_size, backend, host, port)

    # init process groups for different parallel modes from config
    gpc.init_parallel_groups()

    # init dist logger
    init_global_dist_logger()

    # set cuda device
    if torch.cuda.is_available():
        gpc.set_device()
        # if local rank is not given, calculate automatically
        gpc.set_device(local_rank)

    gpc.set_seed(seed)

    if verbose:
        logger = get_dist_logger()
        logger.info(f'Distributed environment is initialized, '
                    f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
                    f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])

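launch is the core entry point that the wrappers further below delegate to. A minimal sketch of calling it directly from each process, assuming a top-level colossalai export, a ./config.py file, and RANK/WORLD_SIZE environment variables set by whatever spawns the processes:

import os

import colossalai

# every process calls launch with its own rank; host/port identify rank 0
colossalai.launch(config='./config.py',
                  rank=int(os.environ['RANK']),
                  world_size=int(os.environ['WORLD_SIZE']),
                  host='127.0.0.1',
                  port=29500,
                  backend='nccl',
                  seed=1024,
                  verbose=True)
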
def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
    '''Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)

    .. note: when pipeline parallel is enabled, shuffle cannot be True
        as it will result in mismatch between input data on the 1st
        stage and label on the last stage

    :param dataset: a :class:utils.data.dataset dataset
    :param seed: random worker seed, defaults to 1024
    :type seed: int, optional
    :param add_sampler_if_possible: [description], defaults to False
    :type add_sampler_if_possible: bool, optional
    :return: a :class:utils.data.dataset dataloader
    :rtype: torch.utils.data.dataset
    '''
    _kwargs = kwargs.copy()
    if 'shuffle' in _kwargs:
        shuffle = _kwargs.pop('shuffle')
    else:
        shuffle = False

    if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
        sampler = DataParallelSampler(dataset, shuffle=shuffle)
    else:
        sampler = None

    # Deterministic dataloader
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    if sampler is None:
        return DataLoader(dataset,
                          worker_init_fn=seed_worker,
                          shuffle=shuffle,
                          **_kwargs)
    else:
        return DataLoader(dataset,
                          sampler=sampler,
                          worker_init_fn=seed_worker,
                          **_kwargs)
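For reference, a short usage sketch of this get_dataloader helper: it seeds the workers for determinism and attaches a DataParallelSampler whenever a data-parallel group with more than one rank is initialized. The CIFAR10 dataset, batch size and worker count below are assumptions used only for illustration:

from torchvision import datasets, transforms

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=transforms.ToTensor())

# get_dataloader is the helper defined above; shuffle is popped from the kwargs
# and handed to the sampler when one is built, the remaining kwargs
# (batch_size, num_workers, ...) go straight to DataLoader
train_dataloader = get_dataloader(train_dataset,
                                  seed=1024,
                                  add_sampler_if_possible=True,
                                  batch_size=64,
                                  shuffle=True,
                                  num_workers=2)
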
def launch_from_slurm(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    rank = int(os.environ['SLURM_PROCID'])
    world_size = int(os.environ['SLURM_NPROCS'])
    launch(config=config,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def initialize(config: Union[str, dict] = None,
               local_rank: int = None,
               world_size: int = None,
               host: str = None,
               port: str = None,
               backend: str = None,
               train_dataloader: Optional[Union[Iterable, Callable]] = None,
               test_dataloader: Optional[Union[Iterable, Callable]] = None,
def launch_from_openmpi(config: Union[str, Path, Config, Dict],
                        host: str,
                        port: int,
                        backend: str = 'nccl',
                        seed: int = 1024,
                        verbose: bool = True):
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)


def launch_from_torch(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)

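The three launch_from_* wrappers differ only in which environment variables they read before delegating to launch. A minimal sketch for the torch.distributed case; the config path and the launcher command are assumptions for illustration:

# train.py, started with e.g.
#   python -m torch.distributed.launch --nproc_per_node=4 train.py
# (or torchrun on newer PyTorch); the launcher sets RANK, LOCAL_RANK and WORLD_SIZE
import colossalai

colossalai.launch_from_torch(config='./config.py',
                             host='127.0.0.1',
                             port=29500,
                             backend='nccl')
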
def initialize(model: Union[nn.Module, List[nn.Module]],
               optimizer: Union[Optimizer, List[Optimizer]],
               criterion: Union[_Loss, List[_Loss]],
               train_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
               test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
               lr_scheduler: _LRScheduler = None,
               verbose: bool = True
               ) -> Tuple[Engine, DataLoader, DataLoader]:
    '''Core function that initializes distributed environment, logger, cudnn, data, model, loss function, optimizer, and lr_scheduler(their configs are in gpc.config).
    ''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.

    :param config: config file or config file path are both acceptable
    :type config: Union[str, dict], optional
    :param local_rank: rank for the default process group, defaults to None
    :type local_rank: int, optional
    :param world_size: world size of GPUs, defaults to None
    :type world_size: int, optional
    :param host: the master address for distributed training, defaults to None
    :type host: str, optional
    :param port: the master port for distributed training, defaults to None
    :type port: str, optional
    :param backend: backend for torch.distributed, defaults to None
    :type backend: str, optional
    :param train_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
    :type train_dataloader: Optional[Union[Iterable, Callable]], optional
    :param test_dataloader: If None, the config is used to build a dataloder; Else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
    :type test_dataloader: Optional[Union[Iterable, Callable]], optional
    :return: (engine, train_dataloader, test_dataloader, criterion)
    :param model: your model instance
    :type model: a single or a list of ``torch.nn.Module`` objects
    :param optimizer: your optimizer instance
    :type optimizer: a single or a list of ``torch.optim.optimizer.Optimizer`` objects
    :param criterion: your criterion instance
    :type criterion: a single or a list of ``torch.nn.modules.loss._Loss`` objects
    :param train_dataloader: dataloaders for training data
    :type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
    :param train_dataloader: dataloaders for testing data
    :type train_dataloader: a single or a list of ``torch.utils.data.DataLoader`` objects, defaults to None
    :return: (engine, criterion, train_dataloader, test_dataloader)
    :rtype: tuple
    '''
    # initialize distributed environment
    init_dist(config=config,
              local_rank=local_rank,
              world_size=world_size,
              host=host,
              port=port,
              backend=backend)
    # get logger
    logger = get_dist_logger()
    gpc.verbose = verbose

    # init logger
    logger = get_global_dist_logger()
    logger.info(f'Distributed environment is initialized, '
                f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
                f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
    # get config from gpc
    config = gpc.config

    # print config
    logger.info(f"\n========== Your Config ========\n"
                f"{pprint.pformat(gpc.config)}\n"
                f"================================", ranks=[0])
    if verbose:
        logger.info(f"\n========== Your Config ========\n"
                    f"{pprint.pformat(gpc.config)}\n"
                    f"================================\n", ranks=[0])

    # cudnn
    cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
    cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
    cudnn_benchmark = config.get('cudnn_benchmark', True)
    cudnn_deterministic = config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    logger.info(
        f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
    if verbose:
        logger.info(
            f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # set seed, cuda seed is only set when cuda is avail
    gpc.set_seed()

    # return_items = list()

    # check fp16 and zero
    should_convert_model_to_half = False
    should_wrap_fp16_optimizer = False
    should_wrap_zero_optimizer_level_2_3 = False

    if hasattr(gpc.config, 'fp16'):
        fp16_mode = gpc.config.fp16.mode
        if fp16_mode == AMP_TYPE.PARALLEL:
            should_convert_model_to_half = True
            should_wrap_fp16_optimizer = True

    if hasattr(gpc.config, 'zero'):
        should_wrap_zero_optimizer_level_2_3 = True
        zero_type = gpc.config.zero.type
        if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
            should_convert_model_to_half = True
            assert not should_wrap_fp16_optimizer, \
                'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'

    # build model
    logger.info('Building model ...', ranks=[0])
    assert hasattr(
        gpc.config, 'model'), "Build error: configuration 'model' is missing"
    if gpc.pipeline_parallel_size > 1:
        model = ModelInitializer(gpc.config.model, 1, verbose=True)
        model = model.model_initialize()
    else:
        model = build_model(gpc.config.model)
        if isinstance(model, BaseModel):
            model.build_from_cfg()
        model = model.to(get_current_device())
    # first sync model across dp ranks
    model.to(get_current_device())
    sync_model_param_in_dp(model)
    logger.info('Model is created', ranks=[0])

    if should_convert_model_to_half:
        model = model.half()
        logger.info("Model is cast to fp16", ranks=[0])
    # check amp and zero
    fp16_cfg = gpc.config.get('fp16', None)
    zero_cfg = gpc.config.get('zero', None)

    # training data
    if callable(train_dataloader):
        logger.info(
            f'Build train data loader from {train_dataloader}', ranks=[0])
        train_dataloader = train_dataloader()
    if train_dataloader is None and hasattr(gpc.config, 'train_data'):
        logger.info('Preparing data ...', ranks=[0])
        # assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing."
        train_dataset = build_dataset(gpc.config.train_data.dataset)
        logger.info('Train dataset is ready.', ranks=[0])
    if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
        raise ConfigException(
            "It is not allowed to set fp16 and zero configuration in your config file at the same time")

        train_dataloader = get_dataloader(train_dataset,
                                          gpc.config.get('seed', 1024),
                                          True,
                                          **gpc.config.train_data.dataloader,
                                          )
        logger.info(
            f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])
    # initialize amp
    amp_mode = None
    if fp16_cfg is not None and fp16_cfg.mode is not None:
        cfg_ = fp16_cfg.copy()
        amp_mode = cfg_.pop('mode')
        model, optimizer, criterion = convert_to_amp(model=model,
                                                     optimizer=optimizer,
                                                     criterion=criterion,
                                                     mode=amp_mode,
                                                     amp_config=cfg_)

    if callable(test_dataloader):
        logger.info(
            f'Build test data loader from {test_dataloader}', ranks=[0])
        test_dataloader = test_dataloader()
    # testing data, allowed to be None
    if test_dataloader is None and hasattr(gpc.config, 'test_data'):
        test_dataset = build_dataset(gpc.config.test_data.dataset)
        test_dataloader = get_dataloader(
            test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
        logger.info(
            f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])
    if zero_cfg is not None:
        cfg_ = zero_cfg.copy()
        level = cfg_.pop('level')
        model, optimizer = convert_to_zero(model=model,
                                           optimizer=optimizer,
                                           level=level,
                                           zero_config=cfg_
                                           )

    # build loss function
    assert hasattr(gpc.config, 'loss'), \
        'Build error: configuration \'loss\' is missing.'
    criterion = build_loss(gpc.config.loss)
    logger.info('Loss function is created', ranks=[0])

    # build optimizer
    assert hasattr(gpc.config, 'optimizer'), \
        "Build error: configuration 'optimizer' is missing."
    optim_type = gpc.config.optimizer.type
    is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
    if is_pytorch_native_zero_level_1:
        original_cfg_copy = gpc.config.optimizer.copy()
        original_cfg_copy.pop('type')
        cfg = dict(type=optim_type, process_group=gpc.get_group(
            ParallelMode.DATA), **original_cfg_copy)
        optimizer = build_optimizer(cfg, model)
    # gradient handler
    gradient_handler_cfg = gpc.config.get('gradient_handler', None)
    if gradient_handler_cfg is None:
        # if gradient handler is not specified in the configuration file,
        # check in the following order
        # 1. if optimizer is ZERO, then use zero grad handler
        # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
        # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                  ZeroRedundancyOptimizer_Level_3)):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
                    "Training with zero is detected, ZeROGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
            model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
            if verbose:
                logger.info(
                    'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
        elif is_using_ddp():
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
    else:
        optimizer = build_optimizer(gpc.config.optimizer, model)
        if not isinstance(gradient_handler_cfg, list):
            raise ConfigException(
                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")

    if should_wrap_zero_optimizer_level_2_3:
        optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)

    if should_wrap_fp16_optimizer:
        # replace the field mode with type
        fp16_cfg = gpc.config.fp16.copy()
        amp_type = fp16_cfg.pop('mode')
        assert amp_type == AMP_TYPE.PARALLEL, 'FP Optimizer should only be used for AMP_TYPE.PARALLEL'
        fp16_cfg['type'] = 'FP16Optimizer'
        optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
    logger.info('Optimizer is created', ranks=[0])

    # build schedule and engine
    if hasattr(gpc.config, 'fp16'):
        amp_type = gpc.config.fp16.mode
        amp_cfg = gpc.config.fp16.copy()
        amp_cfg.pop('mode')
    if gradient_handler_cfg is None:
        gradient_handlers = None
        if verbose and not isinstance(model, DDP):
            logger.warning(
                "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
    else:
        amp_type = None
        amp_cfg = None
        gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

    engine_cfg = gpc.config.get('engine', dict())
    schedule_cfg = engine_cfg.pop('schedule', None)
    # check if optimizer is ColossalaiOptimizer
    if not isinstance(optimizer, (ColossalaiOptimizer, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
        optimizer = ColossalaiOptimizer(optim=optimizer)

    schedule_type = None
    if schedule_cfg is not None:
        schedule_type = schedule_cfg.get('type', None)
    # gradient accumulation
    grad_accum_size = gpc.config.get('gradient_accumulation', None)
    if grad_accum_size is not None:
        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(model=model,
                                                                                           optimizer=optimizer,
                                                                                           dataloader=train_dataloader,
                                                                                           accumulate_size=grad_accum_size,
                                                                                           gradient_handlers=gradient_handlers,
                                                                                           lr_scheduler=lr_scheduler)

    if schedule_type is not None:
        # run customized schedule
        schedule_cfg['amp_type'] = amp_type
        schedule_cfg['amp_config'] = amp_cfg
        schedule = build_schedule(schedule_cfg)
    elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
        assert schedule_cfg is not None, \
            "Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
        schedule = PipelineSchedule(
            amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
    else:
        schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
    # clip grad norm
    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
    if clip_grad_norm > 0:
        if zero_cfg is not None:
            raise ConfigException(
                "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
        elif fp16_cfg is not None and fp16_cfg.mode == AMP_TYPE.NAIVE:
            raise ConfigException(
                "clip_grad_norm should be specified with AMP_TYPE.NAIVE, you should specify clip_grad in fp16 configuration")

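All of the keys consulted above (fp16, zero, gradient_handler, gradient_accumulation, clip_grad_norm, cudnn_*, engine.schedule) come from the user config that launch loads into gpc.config. A hypothetical config sketch tying a few of them together; the concrete values and the choice of torch-style AMP are assumptions, not something this diff prescribes:

# config.py (illustrative only)
from colossalai.amp import AMP_TYPE

# mixed precision; fp16 and zero must not be set at the same time
fp16 = dict(mode=AMP_TYPE.TORCH)

# accumulate gradients over 4 steps before each optimizer update
gradient_accumulation = 4

# global gradient clipping applied by the engine
clip_grad_norm = 1.0

# cudnn flags read at the top of initialize()
cudnn_benchmark = True
cudnn_deterministic = False
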
    engine = Engine(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        step_schedule=schedule,
        **gpc.config.get('engine', dict())
        gradient_handlers=gradient_handlers,
        clip_grad_norm=clip_grad_norm
    )

    return engine, train_dataloader, test_dataloader
    return engine, train_dataloader, test_dataloader, lr_scheduler

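Taken together, a training script under the new API launches the distributed environment first and then hands plain PyTorch objects to initialize, which returns the engine that the trainer loop subsequently drives. A minimal end-to-end sketch; the model, optimizer, dataset and loader settings are assumptions for illustration:

import colossalai
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# assumed to be started by a torch.distributed launcher that sets RANK/LOCAL_RANK/WORLD_SIZE
colossalai.launch_from_torch(config='./config.py', host='127.0.0.1', port=29500)

# ordinary PyTorch components; the choices below are placeholders
model = torchvision.models.resnet18(num_classes=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                         transform=transforms.ToTensor())
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

# wraps the components according to gpc.config (amp, zero, gradient handlers, schedule)
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_loader)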