adapted for sequence parallel (#163)

Frank Lee authored on 2022-01-20 13:44:51 +08:00, committed by GitHub
parent a2e649da39, commit e2089c5c15
17 changed files with 432 additions and 119 deletions

@@ -17,7 +17,7 @@ from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import (accumulate_gradient, get_current_device,
-                              sync_model_param_in_dp, is_using_ddp, is_using_pp)
+                              sync_model_param, is_using_ddp, is_using_pp, is_using_sequence)
 from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
 from colossalai.builder.builder import build_gradient_handler
 from torch.optim.optimizer import Optimizer
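
The import change above swaps the data-parallel-only helper (sync_model_param_in_dp) for a parallel-mode-aware one (sync_model_param) and pulls in is_using_sequence. Below is a minimal sketch of what such a helper is assumed to do: broadcast every parameter from one rank of the chosen process group so all ranks in that group start from identical weights. The function name, signature and broadcast source rank are illustrative assumptions, not the library's actual implementation.

    import torch.distributed as dist

    def sync_model_param_sketch(model, group, src_global_rank=0):
        # hypothetical stand-in for colossalai.utils.sync_model_param:
        # broadcast each parameter over `group` from the process whose
        # *global* rank is `src_global_rank`
        for param in model.parameters():
            dist.broadcast(param.data, src=src_global_rank, group=group)
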
@@ -187,7 +187,7 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       backend: str = 'nccl',
                       seed: int = 1024,
                       verbose: bool = True):
-    '''A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
+    '''A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
     from the environment variables set by PyTorch
 
     :param config: config file or config file path are both acceptable
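
For context, a minimal usage sketch of the wrapper whose docstring is touched above, assuming the script is started with torchrun (or torch.distributed.launch) so that the rank and world-size environment variables are already set; the config path is illustrative:

    # launched e.g. with: torchrun --nproc_per_node=4 train.py
    import colossalai

    colossalai.launch_from_torch(config='./config.py',  # ColossalAI config file (illustrative path)
                                 backend='nccl',
                                 seed=1024,
                                 verbose=True)
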
@@ -270,12 +270,15 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     model.to(get_current_device())
     use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
     if not moe_env.is_initialized() and not use_zero3:
-        sync_model_param_in_dp(model)
+        if is_using_sequence():
+            sync_model_param(model, ParallelMode.SEQUENCE_DP)
+        elif is_using_ddp():
+            sync_model_param(model, ParallelMode.DATA)
     else:
-        print(
-            "Warning: The parameters of models is not automatically synchronized.\n"
+        logger.warning(
+            "The parameters of models is not automatically synchronized.\n"
             "Please make sure that all parameters are the same in data parallel group.",
-            flush=True)
+            ranks=[0])
 
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
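
The new branch above only takes effect when the runtime was configured for sequence parallelism. A hypothetical config sketch that would make is_using_sequence() return True, following ColossalAI's parallel-mode config convention; the exact keys and the 'sequence' mode string are assumptions and do not come from this diff:

    # config.py (hypothetical)
    parallel = dict(
        pipeline=1,
        tensor=dict(size=4, mode='sequence'),  # assumed trigger for the sequence-parallel code path
    )
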
@@ -339,11 +342,16 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif is_using_sequence():
model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
if verbose:
logger.info(
'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
if verbose:
logger.info(
'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
elif is_using_ddp():
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
if verbose:
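
The sequence-parallel branch added above reuses PyTorch's stock DistributedDataParallel, only scoped to a smaller process group so gradient all-reduce stays within the sequence data-parallel ranks. A minimal standalone sketch of that pattern; in the diff the group comes from gpc.get_group(ParallelMode.SEQUENCE_DP), so the group construction mentioned here is illustrative:

    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_for_sequence_dp(model, seq_dp_group):
        # seq_dp_group would be built elsewhere, e.g. via dist.new_group(ranks=[...]);
        # gradients are then all-reduced only across that group, not the full world
        return DDP(model, process_group=seq_dp_group)
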