[refactor] refactor the initialize method for the new zero design (#431)

This commit is contained in:
Jiarui Fang
2022-03-16 19:29:37 +08:00
committed by GitHub
parent 4f85b687cf
commit 640a6cd304
5 changed files with 184 additions and 24 deletions
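
For orientation, here is a minimal usage sketch of the refactored entry point under the new zero design: the model can now be passed as a builder callable and the optimizer as a class, matching the widened signature in the diff below. The sketch is not part of this commit; the builder, the config path, and the unpacked return values are illustrative assumptions.

import colossalai
import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # hypothetical builder; with a `zero` section in the config, initialize()
    # forwards this callable to convert_to_zero_v2 as model_builder instead of
    # constructing the model itself
    return nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))


colossalai.launch_from_torch(config='./config.py')  # config is assumed to carry the `zero` section

engine, *_ = colossalai.initialize(
    model=build_model,            # Union[Callable, nn.Module] is now accepted
    optimizer=torch.optim.Adam,   # Union[Type[Optimizer], Optimizer] is now accepted;
                                  # under zero, the optimizer is built from the zero config
    criterion=nn.CrossEntropyLoss(),
)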


@@ -5,7 +5,7 @@ import argparse
import os
import pprint
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
import torch
import torch.nn as nn
@@ -26,8 +26,9 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer
from colossalai.zero import convert_to_zero_v2
from colossalai.engine.ophooks import BaseOpHook
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
def get_default_parser():
@@ -216,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose=verbose)
def initialize(model: nn.Module,
optimizer: Optimizer,
def initialize(model: Union[Callable, nn.Module],
optimizer: Union[Type[Optimizer], Optimizer],
criterion: Optional[_Loss] = None,
train_dataloader: Optional[Iterable] = None,
test_dataloader: Optional[Iterable] = None,
@@ -227,10 +228,10 @@ def initialize(model: nn.Module,
"""Core function to wrap the essential training components with our functionality based on the config which is
loaded into gpc.config.
:param model: Your model instance
:type model: :class:`torch.nn.Module`
:param model: Your model instance or a function to build the model
:type model: :class:`torch.nn.Module` or Callable
:param optimizer: Your optimizer instance
:type optimizer: :class:`torch.optim.optimizer.Optimizer`
:type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer.Optimizer]`
:param criterion: Your criterion instance
:type criterion: :class:`torch.nn.modules.loss._Loss`, optional
:param train_dataloader: Dataloader for training
@@ -267,10 +268,28 @@ def initialize(model: nn.Module,
if verbose:
logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
# first sync model across dp ranks
model.to(get_current_device())
use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
if not moe_env.is_initialized() and not use_zero3:
use_zero = hasattr(gpc.config, 'zero')
if use_zero:
zero_cfg = gpc.config.get('zero', None)
if zero_cfg is not None:
cfg_ = zero_cfg.copy()
else:
cfg_ = {}
optimizer_config = zero_cfg.get('optimizer', None)
model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
logger.info("Initializing ZeRO model and optimzer finished!", ranks=[0])
# FIXME: throw a warning if ZeRO is used with model parallelism
if gpc.get_world_size(ParallelMode.MODEL) > 1:
logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
else:
if isinstance(model, nn.Module):
# first sync model across dp ranks
model.to(get_current_device())
elif isinstance(model, Callable):
model = model().to(get_current_device())
if not moe_env.is_initialized() and not use_zero:
if is_using_sequence():
sync_model_param(model, ParallelMode.SEQUENCE_DP)
elif is_using_ddp():
@@ -283,16 +302,15 @@ def initialize(model: nn.Module,
# check amp and zero
fp16_cfg = gpc.config.get('fp16', None)
zero_cfg = gpc.config.get('zero', None)
if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
raise ConfigException(
"It is not allowed to set fp16 and zero configuration in your config file at the same time")
# clip grad norm
clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
if clip_grad_norm > 0:
if zero_cfg is not None:
if use_zero and zero_cfg is not None:
raise ConfigException(
"clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
@@ -311,11 +329,6 @@ def initialize(model: nn.Module,
mode=amp_mode,
amp_config=cfg_)
if zero_cfg is not None:
cfg_ = zero_cfg.copy()
level = cfg_.pop('level')
model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
# gradient handler
gradient_handler_cfg = gpc.config.get('gradient_handler', None)
if gradient_handler_cfg is None:
@@ -324,7 +337,7 @@ def initialize(model: nn.Module,
# 1. if optimizer is ZERO, then use zero grad handler
# 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
# 3. if using pipeline and dp size larger than 1, use data parallel grad handler
if isinstance(optimizer, ShardedOptimizer):
if isinstance(optimizer, ShardedOptimizerV2):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
if verbose:
logger.info(
@@ -392,7 +405,7 @@ def initialize(model: nn.Module,
gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
# check if optimizer is ColossalaiOptimizer
if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
optimizer = ColossalaiOptimizer(optim=optimizer)
# gradient accumulation
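
The checks added above constrain how the training config interacts with the new zero path. A hedged sketch of a compatible config module follows; the exact fields accepted inside the `optimizer` sub-config are not shown in this diff, so they are left as assumptions.

# config.py -- sketch only; the fields inside `optimizer` are assumptions
zero = dict(
    optimizer=dict(
        # settings consumed via zero_cfg.get('optimizer', None) and passed to
        # convert_to_zero_v2 as optimizer_config
    ),
)

# Per the checks above, a config that enables `zero` must not also set `fp16`
# or a top-level `clip_grad_norm`; gradient clipping belongs inside the zero
# configuration instead.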