[refactory] refactory the initialize method for new zero design (#431)
@@ -5,7 +5,7 @@ import argparse
 import os
 import pprint
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
 
 import torch
 import torch.nn as nn
@@ -26,8 +26,9 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
-from colossalai.zero import convert_to_zero, ShardedOptimizer
+from colossalai.zero import convert_to_zero_v2
 from colossalai.engine.ophooks import BaseOpHook
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
 
 
 def get_default_parser():
@@ -216,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
            verbose=verbose)
 
 
-def initialize(model: nn.Module,
-               optimizer: Optimizer,
+def initialize(model: Union[Callable, nn.Module],
+               optimizer: Union[Type[Optimizer], Optimizer],
                criterion: Optional[_Loss] = None,
                train_dataloader: Optional[Iterable] = None,
                test_dataloader: Optional[Iterable] = None,
@@ -227,10 +228,10 @@ def initialize(model: nn.Module,
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.
 
-    :param model: Your model instance
-    :type model: :class:`torch.nn.Module`
+    :param model: Your model instance or a function to build the model
+    :type model: :class:`torch.nn.Module` or Callable
     :param optimizer: Your optimizer instance
-    :type optimizer: :class:`torch.optim.optimizer.Optimizer`
+    :type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`
     :param criterion: Your criterion instance
     :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
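After this change, initialize accepts either a constructed nn.Module or a zero-argument builder function for the model, and either an optimizer instance or an optimizer class. Below is a minimal, hedged sketch of the widened calling convention; the config filename, the toy model, and the choice of torch.optim.Adam are placeholders, and whether the class form of the optimizer needs extra settings in the config is not shown in these hunks.

# Hedged usage sketch, not taken from the repository's examples.
import torch
import torch.nn as nn
import colossalai


def build_model() -> nn.Module:
    # Passing a builder instead of an instance lets the ZeRO path decide
    # when and where the parameters are materialised.
    return nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))


colossalai.launch_from_torch(config='./config.py')  # './config.py' is a placeholder path

# After this diff, `model` may be a callable and `optimizer` may be an optimizer class;
# the exact unpacking of the return value depends on the ColossalAI version in use.
engine, *_ = colossalai.initialize(model=build_model,
                                   optimizer=torch.optim.Adam,
                                   criterion=nn.CrossEntropyLoss())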
@@ -267,10 +268,28 @@ def initialize(model: nn.Module,
     if verbose:
         logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
 
-    # first sync model across dp ranks
-    model.to(get_current_device())
-    use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
-    if not moe_env.is_initialized() and not use_zero3:
+    use_zero = hasattr(gpc.config, 'zero')
+    if use_zero:
+        zero_cfg = gpc.config.get('zero', None)
+        if zero_cfg is not None:
+            cfg_ = zero_cfg.copy()
+        else:
+            cfg_ = {}
+        optimizer_config = zero_cfg.get('optimzer', None)
+        model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
+
+        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
+        # FIXME() throw a warning if using zero with MP
+        if gpc.get_world_size(ParallelMode.MODEL) > 1:
+            logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
+    else:
+        if isinstance(model, nn.Module):
+            # first sync model across dp ranks
+            model.to(get_current_device())
+        elif isinstance(model, Callable):
+            model = model().to(get_current_device())
+
+    if not moe_env.is_initialized() and not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
         elif is_using_ddp():
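In isolation, the branch added above routes everything through convert_to_zero_v2 when a zero section exists in the config, and otherwise builds or moves the model as before. The standalone sketch below mirrors that routing with a plain dict in place of gpc.config and a stub in place of convert_to_zero_v2, whose full signature is not shown in this hunk; names such as prepare_model are illustrative only.

from typing import Callable, Optional, Union

import torch.nn as nn


def _convert_to_zero_v2_stub(model_builder: Callable[[], nn.Module],
                             optimizer_config: Optional[dict]):
    # Stand-in for colossalai.zero.convert_to_zero_v2: builds the model and
    # would return a sharded model together with a ShardedOptimizerV2.
    return model_builder(), None


def prepare_model(model: Union[Callable[[], nn.Module], nn.Module], config: dict):
    use_zero = 'zero' in config
    if use_zero:
        zero_cfg = config.get('zero') or {}
        optimizer_config = zero_cfg.get('optimzer', None)  # key spelled exactly as in the diff
        return _convert_to_zero_v2_stub(model_builder=model, optimizer_config=optimizer_config)
    # Non-ZeRO path: accept either a built module or a builder callable.
    if not isinstance(model, nn.Module) and callable(model):
        model = model()
    return model, None


# Example: with a `zero` section present, the builder is handed to the ZeRO path.
sharded_model, sharded_opt = prepare_model(lambda: nn.Linear(16, 16), config={'zero': {}})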
@@ -283,16 +302,15 @@ def initialize(model: nn.Module,
 
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
-    zero_cfg = gpc.config.get('zero', None)
 
-    if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
+    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
         raise ConfigException(
             "It is not allowed to set fp16 and zero configuration in your config file at the same time")
 
     # clip grad norm
     clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
     if clip_grad_norm > 0:
-        if zero_cfg is not None:
+        if use_zero and zero_cfg is not None:
             raise ConfigException(
                 "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
 
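The two checks above enforce that fp16 (AMP) and ZeRO cannot be configured at the same time, and that gradient clipping moves into the zero section once ZeRO is enabled. A self-contained sketch of the same rules, with a plain dict standing in for gpc.config and ValueError standing in for ConfigException:

def validate_config(config: dict) -> None:
    fp16_cfg = config.get('fp16')
    use_zero = 'zero' in config

    # fp16 (AMP) and ZeRO each manage mixed precision, so combining them is rejected.
    if fp16_cfg is not None and fp16_cfg.get('mode') is not None and use_zero:
        raise ValueError(
            "It is not allowed to set fp16 and zero configuration in your config file at the same time")

    # With ZeRO, gradient clipping must be specified inside the zero section instead.
    clip_grad_norm = config.get('clip_grad_norm', 0.0)
    if clip_grad_norm > 0 and use_zero:
        raise ValueError(
            "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")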
@@ -311,11 +329,6 @@ def initialize(model: nn.Module,
                                                  mode=amp_mode,
                                                  amp_config=cfg_)
 
-    if zero_cfg is not None:
-        cfg_ = zero_cfg.copy()
-        level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
-
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
@@ -324,7 +337,7 @@ def initialize(model: nn.Module,
         # 1. if optimizer is ZERO, then use zero grad handler
         # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
         # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
-        if isinstance(optimizer, ShardedOptimizer):
+        if isinstance(optimizer, ShardedOptimizerV2):
             gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
             if verbose:
                 logger.info(
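The comment block above describes how a gradient handler is inferred when none is configured; only the ZeRO rule is visible in this hunk. The sketch below restates the three rules as a plain function. The ZeRO handler name comes from the diff; the data-parallel handler name and the fallback to PyTorch DDP wrapping for rule 2 are assumptions drawn from the comments, not from code shown here.

def infer_gradient_handler_cfg(dp_size: int, using_pipeline: bool, is_sharded_optimizer: bool):
    # Rule 1: a ZeRO (sharded) optimizer gets the ZeRO gradient handler.
    if is_sharded_optimizer:
        return [dict(type='ZeROGradientHandler')]
    # Rule 2: plain data parallelism without pipeline -> no handler; the model
    # is assumed to be wrapped in torch DDP instead.
    if dp_size > 1 and not using_pipeline:
        return None
    # Rule 3: pipeline combined with data parallelism -> a data-parallel handler
    # (name assumed from the comment above).
    if using_pipeline and dp_size > 1:
        return [dict(type='DataParallelGradientHandler')]
    return None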
@@ -392,7 +405,7 @@ def initialize(model: nn.Module,
         gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
 
     # check if optimizer is ColossalaiOptimizer
-    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
+    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
         optimizer = ColossalaiOptimizer(optim=optimizer)
 
     # gradient accumulation
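Finally, the wrapping rule above leaves an already-sharded ShardedOptimizerV2 (or an existing ColossalaiOptimizer) untouched and wraps anything else, so the engine sees a uniform optimizer interface. In the sketch below, the class names, import paths, and the optim= keyword follow this diff; the small usage example around them is illustrative.

import torch
import torch.nn as nn

from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2


def wrap_optimizer(optimizer):
    # Raw torch optimizers get wrapped; ColossalAI / ZeRO-v2 optimizers pass through unchanged.
    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
        optimizer = ColossalaiOptimizer(optim=optimizer)
    return optimizer


model = nn.Linear(8, 8)
wrapped = wrap_optimizer(torch.optim.SGD(model.parameters(), lr=0.1))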