added buffer sync to naive amp model wrapper (#291)

Frank Lee
2022-03-02 16:47:17 +08:00
parent 8d653af408
commit e17e54e32a
4 changed files with 191 additions and 46 deletions
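
What the change does: NaiveAMPModel gains a sync_buffer flag so the wrapper can keep module buffers (for example BatchNorm running statistics) identical across data-parallel ranks, and the diff below teaches initialize() to turn that flag off when the model is additionally wrapped in torch DDP, which broadcasts buffers on its own. The wrapper-side implementation sits in one of the other changed files and is not part of this excerpt; the following is only a minimal sketch of how such buffer synchronization might look, assuming a broadcast from rank 0 of the data-parallel group (the class name and constructor arguments here are illustrative, not the actual NaiveAMPModel internals):

import torch.distributed as dist
import torch.nn as nn


class BufferSyncWrapperSketch(nn.Module):
    """Illustrative stand-in for a naive AMP wrapper with buffer synchronization."""

    def __init__(self, model: nn.Module, sync_buffer: bool = True, process_group=None):
        super().__init__()
        self.model = model
        self.sync_buffer = sync_buffer      # switched off when torch DDP already syncs buffers
        self.process_group = process_group  # data-parallel group to broadcast within

    def _sync_buffers(self):
        # broadcast buffers from rank 0 so every data-parallel replica
        # sees the same buffer values before the forward pass
        for buf in self.model.buffers():
            dist.broadcast(buf, src=0, group=self.process_group)

    def forward(self, *args, **kwargs):
        if self.sync_buffer and dist.is_initialized():
            self._sync_buffers()
        return self.model(*args, **kwargs)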


@@ -16,6 +16,7 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 from colossalai.amp import AMP_TYPE, convert_to_amp
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
@@ -23,8 +24,7 @@ from colossalai.engine import Engine
 from colossalai.global_variables import moe_env
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-from colossalai.utils import (accumulate_gradient, get_current_device,
-                              is_using_ddp, is_using_pp, is_using_sequence,
+from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
@@ -39,21 +39,12 @@ def get_default_parser():
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, help='path to the config file')
-    parser.add_argument('--host',
-                        type=str,
-                        help='the master address for distributed training')
-    parser.add_argument('--port',
-                        type=int,
-                        help='the master port for distributed training')
+    parser.add_argument('--host', type=str, help='the master address for distributed training')
+    parser.add_argument('--port', type=int, help='the master port for distributed training')
     parser.add_argument('--world_size', type=int, help='world size for distributed training')
     parser.add_argument('--rank', type=int, help='rank for the default process group')
-    parser.add_argument('--local_rank',
-                        type=int,
-                        help='local rank on the node')
-    parser.add_argument('--backend',
-                        type=str,
-                        default='nccl',
-                        help='backend for distributed communication')
+    parser.add_argument('--local_rank', type=int, help='local rank on the node')
+    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
     return parser
@@ -116,9 +107,11 @@ def launch(config: Union[str, Path, Config, Dict],
     if verbose:
         logger = get_dist_logger()
-        logger.info(f'Distributed environment is initialized, '
-                    f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
-                    f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
+        logger.info(
+            f'Distributed environment is initialized, '
+            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
+            f'tensor parallel size: {gpc.tensor_parallel_size}',
+            ranks=[0])
 
 
 def launch_from_slurm(config: Union[str, Path, Config, Dict],
@@ -261,9 +254,11 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # print config
     if verbose:
-        logger.info(f"\n========== Your Config ========\n"
-                    f"{pprint.pformat(gpc.config)}\n"
-                    f"================================\n", ranks=[0])
+        logger.info(
+            f"\n========== Your Config ========\n"
+            f"{pprint.pformat(gpc.config)}\n"
+            f"================================\n",
+            ranks=[0])
 
     # cudnn
     cudnn_benchmark = config.get('cudnn_benchmark', True)
@@ -271,8 +266,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     torch.backends.cudnn.benchmark = cudnn_benchmark
     torch.backends.cudnn.deterministic = cudnn_deterministic
     if verbose:
-        logger.info(
-            f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
+        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
 
     # first sync model across dp ranks
     model.to(get_current_device())
@@ -321,11 +315,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     if zero_cfg is not None:
         cfg_ = zero_cfg.copy()
         level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model,
-                                           optimizer=optimizer,
-                                           level=level,
-                                           zero_config=cfg_
-                                           )
+        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
 
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
@@ -350,21 +340,22 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                     "added even though not specified in the configuration",
                     ranks=[0])
         elif is_using_sequence():
-            model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
+            model = DDP(model,
+                        process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                         device_ids=[torch.cuda.current_device()])
             if verbose:
-                logger.info(
-                    'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
+                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
+                            ranks=[0])
         elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
             model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
             if verbose:
-                logger.info(
-                    'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
+                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
         elif is_using_ddp():
             gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
             if verbose:
                 logger.info(
-                    "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
+                    "Data parallel training is detected when using pipeline parallel, "
+                    "DataParallelGradientHandler is automatically "
                     "added even though not specified in the configuration",
                     ranks=[0])
         # add pipeline parallel gradient handler, if pipeline shared module is detected
@@ -383,7 +374,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     else:
         if not isinstance(gradient_handler_cfg, list):
             raise ConfigException(
-                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
+                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
+            )
 
+    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
+    # to avoid duplicated buffer synchronization
+    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
+        model.module.sync_buffer = False
+
     if gradient_handler_cfg is None:
         gradient_handlers = None
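
For context on the duplicated-synchronization comment in the last hunk: torch's DistributedDataParallel broadcasts module buffers on its own (its broadcast_buffers argument defaults to True), so leaving the wrapper's sync enabled as well would repeat the same communication every forward pass. A minimal sketch of the guard in isolation, assuming the model has already been wrapped (the helper name is hypothetical; the check mirrors the one added in the diff above):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.amp.naive_amp import NaiveAMPModel


def disable_duplicate_buffer_sync(model: nn.Module) -> nn.Module:
    # DDP already broadcasts buffers each iteration when broadcast_buffers=True (the default),
    # so the inner NaiveAMPModel should not broadcast them a second time.
    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
        model.module.sync_buffer = False
    return model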