Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 02:26:51 +00:00)
Added MoE parallel (#127)
@@ -5,7 +5,6 @@ import argparse
 import pprint
 import os
-from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 import numpy as np
 import torch
 import torch.nn as nn
 
@@ -26,6 +25,7 @@ from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from torch.nn.modules.loss import _Loss
 from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.global_variables import moe_env
 
 
 def get_default_parser():
@@ -224,7 +224,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
                lr_scheduler: _LRScheduler = None,
                verbose: bool = True
-               ) -> Tuple[Engine, DataLoader, DataLoader]:
+               ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
     ''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
 
     :param model: your model instance
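The hunk above changes only the return-type annotation of initialize: the returned tuple is now annotated as (Engine, DataLoader, DataLoader, _LRScheduler), making explicit that the learning-rate scheduler is part of what callers unpack. A minimal caller-side sketch follows; the optimizer, criterion, and train_dataloader arguments and the prior colossalai.launch call are assumptions about a typical user script, not part of this commit.

# Hypothetical user script fragment: model, optimizer, criterion, the dataloaders
# and lr_scheduler are assumed to already exist, and colossalai.launch(...) is
# assumed to have been called so that gpc.config is populated.
import colossalai

engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    lr_scheduler=lr_scheduler)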
@@ -269,8 +269,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # first sync model across dp ranks
     model.to(get_current_device())
     use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
-    if not use_zero3:
+    if not moe_env.is_initialized() and not use_zero3:
         sync_model_param_in_dp(model)
+    else:
+        print(
+            "Warning: The parameters of models is not automatically synchronized.\n"
+            "Please make sure that all parameters are the same in data parallel group.",
+            flush=True)
 
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
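Skipping sync_model_param_in_dp when moe_env.is_initialized() is the key behavioural change in this hunk: MoE expert parameters are intentionally different across data-parallel ranks, so a blanket broadcast would overwrite them, and the user is instead warned to keep the remaining parameters consistent themselves. For intuition only, a conceptual sketch of what a data-parallel parameter sync amounts to (an illustration, not ColossalAI's actual sync_model_param_in_dp):

# Conceptual sketch, not the library implementation: broadcasting every parameter
# from one rank of the data-parallel group is exactly what must not happen to MoE
# expert weights, hence the moe_env.is_initialized() guard added above.
import torch
import torch.distributed as dist

def broadcast_params_in_dp(model: torch.nn.Module, dp_group=None, src_rank: int = 0):
    # src_rank is a global rank assumed to belong to dp_group
    for param in model.parameters():
        dist.broadcast(param.data, src=src_rank, group=dp_group)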
@@ -327,6 +332,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                     "Training with zero is detected, ZeROGradientHandler is automatically "
                     "added even though not specified in the configuration",
                     ranks=[0])
+        elif is_using_ddp() and moe_env.is_initialized():
+            gradient_handler_cfg = [dict(type='MoeGradientHandler')]
+            if verbose:
+                logger.info(
+                    "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
+                    "added even though not specified in the configuration",
+                    ranks=[0])
         elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
             model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
             if verbose:
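In the gradient-handler auto-detection above, a data-parallel run with MoE initialized now gets a MoeGradientHandler instead of being wrapped in plain PyTorch DDP (the following elif branch), so gradient reduction can respect the MoE parallel groups. Users can also request the handler explicitly in their config rather than relying on auto-detection; a sketch follows, where the gradient_handler key is an assumption inferred from the gradient_handler_cfg variable read by initialize (check the ColossalAI docs for the exact key in your version):

# config.py sketch (hypothetical user config). The 'gradient_handler' key name is
# an assumption, not confirmed by this commit; the value mirrors the list-of-dicts
# form that the code above builds automatically.
gradient_handler = [dict(type='MoeGradientHandler')]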