Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)

Commit: adapted for sequence parallel (#163)
@@ -17,7 +17,7 @@ from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import (accumulate_gradient, get_current_device,
-                              sync_model_param_in_dp, is_using_ddp, is_using_pp)
+                              sync_model_param, is_using_ddp, is_using_pp, is_using_sequence)
 from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
 from colossalai.builder.builder import build_gradient_handler
 from torch.optim.optimizer import Optimizer
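The import change above replaces the data-parallel-only helper sync_model_param_in_dp with the more general sync_model_param, which takes a parallel mode, and adds is_using_sequence. Conceptually, synchronizing a model inside a process group is just a broadcast of every parameter; the sketch below shows that idea with plain torch.distributed (broadcast_model_params is an illustrative helper, not ColossalAI's internal implementation):

import torch
import torch.distributed as dist

def broadcast_model_params(model: torch.nn.Module, group=None, src: int = 0):
    # Broadcast every parameter from global rank `src` so that all ranks in
    # `group` start training from identical weights.
    for param in model.parameters():
        dist.broadcast(param.data, src=src, group=group)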
@@ -187,7 +187,7 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       backend: str = 'nccl',
                       seed: int = 1024,
                       verbose: bool = True):
     '''A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
     from the environment variables set by PyTorch

     :param config: config file or config file path are both acceptable
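For reference, the "environment variables set by PyTorch" that launch_from_torch relies on are the standard ones exported by torchrun / torch.distributed.launch for every worker process; a wrapper only needs to read them and forward the values. A minimal sketch (not the actual ColossalAI code):

import os

# torchrun / torch.distributed.launch export these per worker process.
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])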
@@ -270,12 +270,15 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
         model.to(get_current_device())
     use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
     if not moe_env.is_initialized() and not use_zero3:
-        sync_model_param_in_dp(model)
+        if is_using_sequence():
+            sync_model_param(model, ParallelMode.SEQUENCE_DP)
+        elif is_using_ddp():
+            sync_model_param(model, ParallelMode.DATA)
     else:
-        print(
-            "Warning: The parameters of models is not automatically synchronized.\n"
+        logger.warning(
+            "The parameters of models is not automatically synchronized.\n"
             "Please make sure that all parameters are the same in data parallel group.",
-            flush=True)
+            ranks=[0])

     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
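The new branch picks the group the weights are synchronized over: the SEQUENCE_DP group when sequence parallelism is enabled, otherwise the plain DATA group. Not part of the commit, but a cheap sanity check one could run right after this point to confirm the sync worked, using only standard torch.distributed collectives (assert_params_in_sync is hypothetical):

import torch
import torch.distributed as dist

def assert_params_in_sync(model: torch.nn.Module, group=None):
    # Every rank contributes the sum of its parameters; if the replicas match,
    # the MIN and MAX of that sum across the group are equal.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # NCCL needs CUDA tensors
    local = sum(p.detach().double().sum().item() for p in model.parameters())
    lo = torch.tensor([local], dtype=torch.float64, device=device)
    hi = lo.clone()
    dist.all_reduce(lo, op=dist.ReduceOp.MIN, group=group)
    dist.all_reduce(hi, op=dist.ReduceOp.MAX, group=group)
    assert torch.equal(lo, hi), 'model parameters differ across the group'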
@@ -339,11 +342,16 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                 "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
+    elif is_using_sequence():
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
+        if verbose:
+            logger.info(
+                'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
         model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
         if verbose:
             logger.info(
-                'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
+                'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
     elif is_using_ddp():
         gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
         if verbose:
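The sequence-parallel branch reuses vanilla PyTorch DDP but hands it the SEQUENCE_DP process group, so gradients are all-reduced only across the ranks in that group rather than the whole world. The same pattern in plain PyTorch looks roughly like this (the rank list and group construction are illustrative; ColossalAI obtains its groups from gpc):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Suppose ranks 0 and 2 should average gradients with each other (illustrative).
replica_ranks = [0, 2]
# new_group must be called by all processes in the default group.
replica_group = dist.new_group(ranks=replica_ranks)

model = torch.nn.Linear(16, 16).to('cuda' if torch.cuda.is_available() else 'cpu')
if dist.get_rank() in replica_ranks:
    # DDP only all-reduces gradients across the ranks in `process_group`.
    model = DDP(model, process_group=replica_group)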