mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-03 17:19:51 +00:00
* init
* rename and remove useless func
* basic chunk
* add evoformer
* align evoformer
* add meta
* basic chunk
* basic memory
* finish basic inference memory estimation
* finish memory estimation
* fix bug
* finish memory estimation
* add part of index tracer
* finish basic index tracer
* add doc string
* add doc str
* polish code
* polish code
* update active log
* polish code
* add possible region search
* finish region search loop
* finish chunk define
* support new op
* rename index tracer
* finish codegen on msa
* redesign index tracer, add source and change compute
* pass outproduct mean
* code format
* code format
* work with outerproductmean and msa
* code style
* code style
* code style
* code style
* change threshold
* support check_index_duplicate
* support index duplicate and update loop
* support output
* update memory estimate
* optimise search
* fix layernorm
* move flow tracer
* refactor flow tracer
* format code
* refactor flow search
* code style
* adapt codegen to prepose node
* code style
* remove abandoned function
* remove flow tracer
* code style
* code style
* reorder nodes
* finish node reorder
* update run
* code style
* add chunk select class
* add chunk select
* code style
* add chunksize in emit, fix bug in reassign shape
* code style
* turn off print mem
* add evoformer openfold init
* init openfold
* add benchmark
* add print
* code style
* code style
* init openfold
* update openfold
* align openfold
* use max_mem to control strategy
* update source add
* add reorder in mem estimator
* improve reorder efficiency
* support ones_like, add prompt if fit mode search fail
* fix a bug in ones like, dont gen chunk if dim size is 1
* fix bug again
* update min memory strategy, reduce mem usage by 30%
* last version of benchmark
* refactor structure
* restruct dir
* update test
* rename
* take apart chunk code gen
* close mem and code print
* code format
* rename ambiguous variable
* separate flow tracer
* separate input node dim search
* separate prepose_nodes
* separate non chunk input
* separate reorder
* rename
* add reorder graph
* separate trace flow
* code style
* code style
* fix typo
* set benchmark
* rename test
* update codegen test
* Fix state_dict key missing issue of the ZeroDDP (#2363)
* Fix state_dict output for ZeroDDP duplicated parameters
* Rewrite state_dict based on get_static_torch_model
* Modify get_static_torch_model to be compatible with the lower version (ZeroDDP)
* update codegen test
* update codegen test
* add chunk search test
* code style
* add available
* [hotfix] fix gpt gemini example (#2404)
* [hotfix] fix gpt gemini example
* [example] add new assertions
* remove autochunk_available
* [workflow] added nightly release to pypi (#2403)
* add comments
* code style
* add doc for search chunk
* [doc] updated readme regarding pypi installation (#2406)
* add doc for search
* [doc] updated kernel-related optimisers' docstring (#2385)
* [doc] updated kernel-related optimisers' docstring
* polish doc
* rename trace_index to trace_indice
* rename function from index to indice
* rename
* rename in doc
* [polish] polish code for get_static_torch_model (#2405)
* [gemini] polish code
* [testing] remove code
* [gemini] make more robust
* rename
* rename
* remove useless function
* [workflow] added coverage test (#2399)
* [workflow] added coverage test
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* add doc for trace indice
* [docker] updated Dockerfile and release workflow (#2410)
* add doc
* update doc
* add available
* change imports
* add test in import
* [workflow] refactored the example check workflow (#2411)
* [workflow] refactored the example check workflow
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* Update parallel_context.py (#2408)
* [hotfix] add DISTPAN argument for benchmark (#2412)
* change the benchmark config file
* change config
* revert config file
* rename distpan to distplan
* [workflow] added precommit check for code consistency (#2401)
* [workflow] added precommit check for code consistency
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
* adapt new fx
* [workflow] added translation for non-english comments (#2414)
* [setup] refactored setup.py for dependency graph (#2413)
* change import
* update doc
* [workflow] auto comment if precommit check fails (#2417)
* [hotfix] add norm clearing for the overflow step (#2416)
* [examples] adding tflops to PaLM (#2365)
* [workflow]auto comment with test coverage report (#2419)
* [workflow]auto comment with test coverage report
* polish code
* polish yaml
* [doc] added documentation for CI/CD (#2420)
* [doc] added documentation for CI/CD
* polish markdown
* polish markdown
* polish markdown
* [example] removed duplicated stable diffusion example (#2424)
* [zero] add inference mode and its unit test (#2418)
* [workflow] report test coverage even if below threshold (#2431)
* [example] improved the clarity of the example readme (#2427)
* [example] improved the clarity of the example readme
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* polish workflow
* [ddp] add is_ddp_ignored (#2434)
[ddp] rename to is_ddp_ignored
* [workflow] make test coverage report collapsable (#2436)
* [autoparallel] add shard option (#2423)
* [fx] allow native ckpt trace and codegen. (#2438)
* [cli] provided more details if colossalai run fail (#2442)
* [autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize
* add megatron solution
* update gpt autoparallel examples with latest api
* adapt beta value to fit the current computation cost
* [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)
* [ddp] add is_ddp_ignored
[ddp] rename to is_ddp_ignored
* [zero] fix state_dict and load_state_dict
* fix bugs
* [zero] update unit test for ZeroDDP
* [example] updated the hybrid parallel tutorial (#2444)
* [example] updated the hybrid parallel tutorial
* polish code
* [zero] add warning for ignored parameters (#2446)
* [example] updated large-batch optimizer tutorial (#2448)
* [example] updated large-batch optimizer tutorial
* polish code
* polish code
* [example] fixed seed error in train_dreambooth_colossalai.py (#2445)
* [workflow] fixed the on-merge condition check (#2452)
* [workflow] automated the compatibility test (#2453)
* [workflow] automated the compatibility test
* polish code
* [autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
* [workflow] automated bdist wheel build (#2459)
* [workflow] automated bdist wheel build
* polish workflow
* polish readme
* polish readme
* Fix False warning in initialize.py (#2456)
* Update initialize.py
* pre-commit run check
* [examples] update autoparallel tutorial demo (#2449)
* [examples] update autoparallel tutorial demo
* add test_ci.sh
* polish
* add conda yaml
* [cli] fixed hostname mismatch error (#2465)
* [example] integrate autoparallel demo with CI (#2466)
* [example] integrate autoparallel demo with CI
* polish code
* polish code
* polish code
* polish code
* [zero] low level optim supports ProcessGroup (#2464)
* [example] update vit ci script (#2469)
* [example] update vit ci script
* [example] update requirements
* [example] update requirements
* [example] integrate seq-parallel tutorial with CI (#2463)
* [zero] polish low level optimizer (#2473)
* polish pp middleware (#2476)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [example] update gpt gemini example ci test (#2477)
* [zero] add unit test for low-level zero init (#2474)
* [workflow] fixed the skip condition of example weekly check workflow (#2481)
* [example] stable diffusion add roadmap
* add dummy test_ci.sh
* [example] stable diffusion add roadmap (#2482)
* [CI] add test_ci.sh for palm, opt and gpt (#2475)
* polish code
* [example] titans for gpt
* polish readme
* remove license
* polish code
* update readme
* [example] titans for gpt (#2484)
* [autoparallel] support origin activation ckpt on autoparallel system (#2468)
* [autochunk] support evoformer tracer (#2485)
support full evoformer tracer, which is a main module of alphafold. previously we just supported a simplified version of it.
1. support some evoformer's op in fx
2. support evoformer test
3. add repos for test code
* [example] fix requirements (#2488)
* [zero] add unit testings for hybrid parallelism (#2486)
* [hotfix] gpt example titans bug #2493
* polish code and fix dataloader bugs
* [hotfix] gpt example titans bug #2493 (#2494)
* [fx] allow control of ckpt_codegen init (#2498)
* [fx] allow control of ckpt_codegen init
Currently in ColoGraphModule, ActivationCheckpointCodeGen will be set automatically in __init__. But other codegen can't be set if so.
So I add an arg to control whether to set ActivationCheckpointCodeGen in __init__.
* code style
* [example] dreambooth example
* add test_ci.sh to dreambooth
* [autochunk] support autochunk on evoformer (#2497)
* Revert "Update parallel_context.py (#2408)"
This reverts commit 7d5640b9db
.
* add avg partition (#2483)
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
* [auto-chunk] support extramsa (#3) (#2504)
* [utils] lazy init. (#2148)
* [utils] lazy init.
* [utils] remove description.
* [utils] complete.
* [utils] finalize.
* [utils] fix names.
* [autochunk] support parsing blocks (#2506)
* [zero] add strict ddp mode (#2508)
* [zero] add strict ddp mode
* [polish] add comments for strict ddp mode
* [zero] fix test error
* [doc] update opt and tutorial links (#2509)
* [workflow] fixed changed file detection (#2515)
Co-authored-by: oahzxl <xuanlei.zhao@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Haofan Wang <haofanwang.ai@gmail.com>
Co-authored-by: Jiarui Fang <fangjiarui123@gmail.com>
Co-authored-by: ZijianYY <119492445+ZijianYY@users.noreply.github.com>
Co-authored-by: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Co-authored-by: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang97@gmail.com>
Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Co-authored-by: oahzxl <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
472 lines
21 KiB
Python
472 lines
21 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import argparse
|
|
import os
|
|
import pprint
|
|
from pathlib import Path
|
|
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn.modules.loss import _Loss
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
from torch.optim.optimizer import Optimizer
|
|
from torch.utils.data import DataLoader
|
|
|
|
from colossalai.amp import AMP_TYPE, convert_to_amp
|
|
from colossalai.amp.naive_amp import NaiveAMPModel
|
|
from colossalai.builder.builder import build_gradient_handler
|
|
from colossalai.context import Config, ConfigException, ParallelMode
|
|
from colossalai.context.moe_context import MOE_CONTEXT
|
|
from colossalai.core import global_context as gpc
|
|
from colossalai.engine import Engine
|
|
from colossalai.engine.gradient_accumulation import accumulate_gradient
|
|
from colossalai.engine.schedule import (
|
|
InterleavedPipelineSchedule,
|
|
NonPipelineSchedule,
|
|
PipelineSchedule,
|
|
get_tensor_shape,
|
|
)
|
|
from colossalai.gemini.ophooks import BaseOpHook
|
|
from colossalai.logging import get_dist_logger
|
|
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
|
|
from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
|
|
from colossalai.utils.moe import sync_moe_model_param
|
|
from colossalai.zero import convert_to_zero_v2
|
|
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
|
|
|
|
|
|
def get_default_parser():
    """Build the argument parser holding the standard distributed-launch options.

    The parser knows about the config path, master host/port, world size, global rank,
    local rank and the ``torch.distributed`` backend (``nccl`` by default). Callers may
    register additional, script-specific arguments on the returned parser before parsing.

    Returns:
        argparse.ArgumentParser: Parser pre-populated with the default launch arguments.
    """
    # (flag, type, help) for every option that has no default value
    optional_specs = (
        ('--config', str, 'path to the config file'),
        ('--host', str, 'the master address for distributed training'),
        ('--port', int, 'the master port for distributed training'),
        ('--world_size', int, 'world size for distributed training'),
        ('--rank', int, 'rank for the default process group'),
        ('--local_rank', int, 'local rank on the node'),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, value_type, help_text in optional_specs:
        arg_parser.add_argument(flag, type=value_type, help=help_text)

    # the communication backend is the only option with a non-None default
    arg_parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
    return arg_parser
|
|
|
|
|
|
def launch(config: Union[str, Path, Config, Dict],
           rank: int,
           world_size: int,
           host: str,
           port: int,
           backend: str = 'nccl',
           local_rank: Optional[int] = None,
           seed: int = 1024,
           verbose: bool = True):
    """Load the configuration and initialize the distributed environment through the global context.

    The steps run in this order:

    1. Load ``config`` into ``gpc``.
    2. Create the default process group.
    3. Create the process groups for every parallel mode in the config.
    4. Select the CUDA device, if CUDA is available.
    5. Count the processes on the current node.
    6. Set the random seed.

    Args:
        config (Union[str, Path, dict, Config]): A ``Config`` object, a plain dict, or a path to a config file.
        rank (int): Rank for the default process group.
        world_size (int): World size of the default process group.
        host (str): The master address for distributed training.
        port (int): The master port for distributed training.
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``.
        local_rank (int, optional):
            Rank for the process on the node and is used to set the default CUDA device,
            defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.

    Raises:
        AssertionError: If ``config`` is not a ``Config``, dict, str or Path.
    """
    gpc.verbose = verbose

    # set config
    assert isinstance(config, (Config, str, Path, dict)), \
        f'expected argument config to be Config, dict, str or Path, but got {type(config)}'
    # a plain dict is wrapped; an existing Config object is used as-is
    if not isinstance(config, Config) and isinstance(config, dict):
        config = Config(config)
    if isinstance(config, (str, Path)):
        config = Config.from_file(config)
    gpc.load_config(config)

    # init default process group
    gpc.init_global_dist(rank, world_size, backend, host, port)

    # init process groups for different parallel modes from config
    gpc.init_parallel_groups()

    # set cuda device
    if torch.cuda.is_available():
        # if local rank is not given, calculate automatically
        gpc.set_device(local_rank)

    # set the number of processes running on the same node
    gpc.detect_num_processes_on_current_node()

    gpc.set_seed(seed)

    if verbose:
        logger = get_dist_logger()
        logger.info(
            f'Distributed environment is initialized, '
            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
            f'tensor parallel size: {gpc.tensor_parallel_size}',
            ranks=[0])
|
|
|
|
|
|
def launch_from_slurm(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
    set by SLURM.

    The rank is read from ``SLURM_PROCID`` and the world size from ``SLURM_NPROCS``. ``local_rank`` is not passed,
    so ``launch`` leaves device selection to ``gpc.set_device(None)``.

    Args:
        config (Union[str, Path, dict, Config]): A ``Config`` object, a plain dict, or a path to a config file.
        host (str): The master address for distributed training.
        port (int): The master port for distributed training.
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``.
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.

    Raises:
        RuntimeError: If a required SLURM environment variable is missing. The original ``KeyError`` is chained
            as ``__cause__``.
    """
    try:
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NPROCS'])
    except KeyError as e:
        # chain the KeyError so the traceback still shows which lookup failed
        raise RuntimeError(
            f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
        ) from e

    launch(config=config,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)
|
|
|
|
|
|
def launch_from_openmpi(config: Union[str, Path, Config, Dict],
                        host: str,
                        port: int,
                        backend: str = 'nccl',
                        seed: int = 1024,
                        verbose: bool = True):
    """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
    set by OpenMPI.

    The variables read are:

    - rank: ``OMPI_COMM_WORLD_RANK``
    - local rank: ``OMPI_COMM_WORLD_LOCAL_RANK``
    - world size: ``OMPI_COMM_WORLD_SIZE``

    Args:
        config (Union[str, Path, dict, Config]): A ``Config`` object, a plain dict, or a path to a config file.
        host (str): The master address for distributed training.
        port (int): The master port for distributed training.
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``.
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.

    Raises:
        RuntimeError: If a required OpenMPI environment variable is missing. The original ``KeyError`` is chained
            as ``__cause__``.
    """
    try:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except KeyError as e:
        # chain the KeyError so the traceback still shows which lookup failed
        raise RuntimeError(
            f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
        ) from e

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)
|
|
|
|
|
|
def launch_from_torch(config: Union[str, Path, Config, Dict],
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
    """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
    from the environment variables set by PyTorch.

    The variables read are ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT``.

    Args:
        config (Union[str, Path, dict, Config]): A ``Config`` object, a plain dict, or a path to a config file.
        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``.
        seed (int, optional): Specified random seed for every process. Defaults to 1024.
        verbose (bool, optional): Whether to print logs. Defaults to True.

    Raises:
        RuntimeError: If a required torch launcher environment variable is missing. The original ``KeyError`` is
            chained as ``__cause__``.
    """
    try:
        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        host = os.environ['MASTER_ADDR']
        port = int(os.environ['MASTER_PORT'])
    except KeyError as e:
        # chain the KeyError so the traceback still shows which lookup failed
        raise RuntimeError(
            f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
        ) from e

    launch(config=config,
           local_rank=local_rank,
           rank=rank,
           world_size=world_size,
           host=host,
           port=port,
           backend=backend,
           seed=seed,
           verbose=verbose)
|
|
|
|
|
|
def initialize(model: nn.Module,
               optimizer: Optimizer,
               criterion: Optional[_Loss] = None,
               train_dataloader: Optional[Iterable] = None,
               test_dataloader: Optional[Iterable] = None,
               lr_scheduler: Optional[_LRScheduler] = None,
               ophooks: Optional[List[BaseOpHook]] = None,
               verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
    """Core function to wrap the essential training components with our functionality based on the config which is
    loaded into gpc.config.

    Args:
        model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
        optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
            Your optimizer instance, or (without ZeRO) an optimizer class called with ``model.parameters()``.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
        lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
        ophooks (List[BaseOpHook], optional): Op hooks forwarded to the ``Engine`` as ``ophook_list``.
        verbose (bool, optional): Whether to print logs.

    Returns:
        Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
            A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
            where only ``engine`` could not be None.

    Raises:
        ConfigException: If both ``fp16`` (with a mode) and ``zero`` are configured, or if ``gradient_handler``
            in the config is not a list.
    """
    # get logger
    logger = get_dist_logger()
    gpc.verbose = verbose

    # get config from gpc
    config = gpc.config

    # print config
    if verbose:
        logger.info(
            f"\n========== Your Config ========\n"
            f"{pprint.pformat(gpc.config)}\n"
            f"================================\n",
            ranks=[0])

    # cudnn
    cudnn_benchmark = config.get('cudnn_benchmark', False)
    cudnn_deterministic = config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if verbose:
        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # zero
    use_zero = hasattr(gpc.config, 'zero')
    if use_zero:
        zero_cfg = gpc.config.get('zero', None)
        # `zero = None` in the config still enables ZeRO; fall back to an empty dict so the
        # sub-config lookups below do not call `.get` on None.
        cfg_ = zero_cfg.copy() if zero_cfg is not None else {}
        optimizer_config = cfg_.get('optimizer_config', None)
        model_config = cfg_.get('model_config', None)
        model, optimizer = convert_to_zero_v2(model,
                                              optimizer,
                                              model_config=model_config,
                                              optimizer_config=optimizer_config)

        if verbose:
            logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
    else:
        if isinstance(model, nn.Module):
            # first sync model across dp ranks
            model.to(get_current_device())
        elif isinstance(model, Callable):
            model = model().to(get_current_device())

        # optimizer maybe a optimizer_cls
        if isinstance(optimizer, Callable):
            optimizer = optimizer(model.parameters())
            logger.warning("Initializing an non ZeRO model with optimizer class")

    # broadcast parameters so every replica starts from identical weights
    if not use_zero:
        if is_using_sequence():
            sync_model_param(model, ParallelMode.SEQUENCE_DP)
        elif MOE_CONTEXT.is_initialized:
            sync_moe_model_param(model)
        elif is_using_ddp():
            sync_model_param(model, ParallelMode.DATA)
    else:
        logger.warning(
            "The parameters of models is not automatically synchronized.\n"
            "Please make sure that all parameters are the same in data parallel group.",
            ranks=[0])

    # check amp and zero
    fp16_cfg = gpc.config.get('fp16', None)

    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
        raise ConfigException(
            "It is not allowed to set fp16 and zero configuration in your config file at the same time")

    # clip grad norm
    clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)

    # initialize amp
    amp_mode = None
    if fp16_cfg is not None and fp16_cfg.mode is not None:
        cfg_ = fp16_cfg.copy()
        amp_mode = cfg_.pop('mode')
        if is_using_pp():
            assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
        if amp_mode == AMP_TYPE.NAIVE:
            # naive AMP performs the gradient clipping itself
            cfg_['clip_grad_norm'] = clip_grad_norm
        model, optimizer, criterion = convert_to_amp(model=model,
                                                     optimizer=optimizer,
                                                     criterion=criterion,
                                                     mode=amp_mode,
                                                     amp_config=cfg_)

    # get torch ddp config
    torch_ddp_cfg = gpc.config.get('torch_ddp', dict())

    # gradient handler
    gradient_handler_cfg = gpc.config.get('gradient_handler', None)
    if gradient_handler_cfg is None:
        # if gradient handler is not specified in the configuration file,
        # check in the following order
        # 1. if optimizer is ZERO, then use zero grad handler
        # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
        # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
        if isinstance(optimizer, ShardedOptimizerV2):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
                    "Training with zero is detected, ZeROGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_ddp() and MOE_CONTEXT.is_initialized:
            gradient_handler_cfg = [dict(type='MoeGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        elif is_using_sequence():
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
                            ranks=[0])
        elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
            model = DDP(model,
                        process_group=gpc.get_group(ParallelMode.DATA),
                        device_ids=[torch.cuda.current_device()],
                        **torch_ddp_cfg)
            if verbose:
                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
        elif is_using_ddp():
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            if verbose:
                logger.info(
                    "Data parallel training is detected when using pipeline parallel, "
                    "DataParallelGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
        # add pipeline parallel gradient handler, if pipeline shared module is detected
        for param in model.parameters():
            if getattr(param, 'pipeline_shared_module_pg', None) is not None:
                if gradient_handler_cfg is None:
                    gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
                else:
                    gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
                if verbose:
                    logger.info(
                        "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
                        "added even though not specified in the configuration",
                        ranks=[0])
                # one handler covers all shared parameters
                break
    else:
        if not isinstance(gradient_handler_cfg, list):
            raise ConfigException(
                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
            )

    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
    # to avoid duplicated buffer synchronization
    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
        model.module.sync_buffer = False

    # initialize schedule for engine
    if is_using_pp():
        tensor_shape = get_tensor_shape()
        use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
        if gpc.is_initialized(ParallelMode.PARALLEL_1D):
            scatter_gather = True
        else:
            scatter_gather = False
        if use_interleaved:
            if isinstance(model, nn.Sequential):
                # the interleaved schedule iterates over model chunks
                model = nn.ModuleList([model])
            schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                                   gpc.config.model.num_chunks,
                                                   tensor_shape=tensor_shape,
                                                   scatter_gather_tensors=scatter_gather)
        else:
            schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                        tensor_shape=tensor_shape,
                                        scatter_gather_tensors=scatter_gather)
    else:
        schedule = NonPipelineSchedule()

    if gradient_handler_cfg is None:
        gradient_handlers = None
        if verbose and not isinstance(model, DDP):
            logger.warning(
                "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
    else:
        gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

    # check if optimizer is ColossalaiOptimizer
    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
        optimizer = ColossalaiOptimizer(optim=optimizer)

    # gradient accumulation
    grad_accum_size = gpc.config.get('gradient_accumulation', None)
    if grad_accum_size is not None:
        optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
            model=model,
            optimizer=optimizer,
            dataloader=train_dataloader,
            accumulate_size=grad_accum_size,
            gradient_handlers=gradient_handlers,
            lr_scheduler=lr_scheduler)
    engine = Engine(model=model,
                    optimizer=optimizer,
                    criterion=criterion,
                    gradient_handlers=gradient_handlers,
                    clip_grad_norm=clip_grad_norm,
                    ophook_list=ophooks,
                    schedule=schedule)

    return engine, train_dataloader, test_dataloader, lr_scheduler
|