Added MoE parallel (#127)

Authored by HELSON on 2022-01-07 15:08:36 +08:00, committed by GitHub
parent 42741dd4a3
commit dceae85195
26 changed files with 858 additions and 18 deletions


@@ -5,7 +5,6 @@ import argparse
import pprint
import os
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
import numpy as np
import torch
import torch.nn as nn
@@ -26,6 +25,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.global_variables import moe_env


def get_default_parser():
@@ -224,7 +224,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
               test_dataloader: Optional[Union[Iterable, List[Iterable]]] = None,
               lr_scheduler: _LRScheduler = None,
               verbose: bool = True
-               ) -> Tuple[Engine, DataLoader, DataLoader]:
+               ) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
    ''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.

    :param model: your model instance
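Note on the hunk above: the return annotation now includes the scheduler, so a call site unpacks four values from colossalai.initialize, with the lr_scheduler passed in (or None) handed back last. A minimal, hypothetical call-site sketch, assuming the distributed context has already been set up through one of the colossalai.launch* entry points and that initialize accepts the optimizer/criterion keywords as in this version of the signature:

```python
import colossalai
import torch
import torch.nn as nn

# Placeholder components; any real model/optimizer/criterion would do.
model = nn.Linear(32, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Four return values instead of three: the scheduler comes back as well.
# Dataloaders are optional (default None), so they are omitted in this sketch.
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    lr_scheduler=scheduler,
)
```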
@@ -269,8 +269,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
    # first sync model across dp ranks
    model.to(get_current_device())
    use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
-    if not use_zero3:
+    if not moe_env.is_initialized() and not use_zero3:
        sync_model_param_in_dp(model)
+    else:
+        print(
+            "Warning: The parameters of models is not automatically synchronized.\n"
+            "Please make sure that all parameters are the same in data parallel group.",
+            flush=True)

    # check amp and zero
    fp16_cfg = gpc.config.get('fp16', None)
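The hunk above gates the rank-0 parameter broadcast on moe_env.is_initialized(): with MoE parallelism, expert parameters are intentionally different across data-parallel ranks, so a blanket sync would overwrite them, hence the printed warning instead of an automatic sync. A generic torch.distributed illustration of that idea (not ColossalAI's sync_model_param_in_dp; the is_moe_param flag is a hypothetical tag for expert weights):

```python
import torch
import torch.distributed as dist

def sync_non_expert_params(model: torch.nn.Module, dp_group=None):
    """Broadcast only the parameters that should be identical on all
    data-parallel ranks; expert weights stay rank-local."""
    for param in model.parameters():
        if getattr(param, 'is_moe_param', False):
            continue  # skipping: expert parameters are meant to differ per rank
        dist.broadcast(param.data, src=0, group=dp_group)
```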
@@ -327,6 +332,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                    "Training with zero is detected, ZeROGradientHandler is automatically "
                    "added even though not specified in the configuration",
                    ranks=[0])
+        elif is_using_ddp() and moe_env.is_initialized():
+            gradient_handler_cfg = [dict(type='MoeGradientHandler')]
+            if verbose:
+                logger.info(
+                    "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
+                    "added even though not specified in the configuration",
+                    ranks=[0])
        elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
            model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
            if verbose:
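When the MoE branch above is taken, the model is not wrapped in DDP by the later elif branch, so gradients of the replicated (non-expert) parameters have to be all-reduced by hand after backward; that is what registering a gradient handler here is for. A simplified, hypothetical handler in that style (an illustration only, not the MoeGradientHandler added by this commit; the is_moe_param flag is again assumed):

```python
import torch
import torch.distributed as dist

class SimpleDataParallelGradientHandler:
    """Illustrative only: average replicated gradients after backward.

    A real MoE handler would additionally reduce expert gradients over
    their own, smaller parallel group; this sketch simply skips them.
    """

    def __init__(self, model: torch.nn.Module, dp_group=None):
        self._model = model
        self._dp_group = dp_group

    def handle_gradient(self):
        world_size = dist.get_world_size(self._dp_group)
        for param in self._model.parameters():
            if param.grad is None or getattr(param, 'is_moe_param', False):
                continue
            # Average the gradient across the data-parallel ranks.
            dist.all_reduce(param.grad.data, group=self._dp_group)
            param.grad.data.div_(world_size)
```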