Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
added buffer sync to naive amp model wrapper (#291)
@@ -16,6 +16,7 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
 from colossalai.amp import AMP_TYPE, convert_to_amp
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
@@ -23,8 +24,7 @@ from colossalai.engine import Engine
 from colossalai.global_variables import moe_env
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-from colossalai.utils import (accumulate_gradient, get_current_device,
-                              is_using_ddp, is_using_pp, is_using_sequence,
+from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
@@ -39,21 +39,12 @@ def get_default_parser():
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, help='path to the config file')
-    parser.add_argument('--host',
-                        type=str,
-                        help='the master address for distributed training')
-    parser.add_argument('--port',
-                        type=int,
-                        help='the master port for distributed training')
+    parser.add_argument('--host', type=str, help='the master address for distributed training')
+    parser.add_argument('--port', type=int, help='the master port for distributed training')
     parser.add_argument('--world_size', type=int, help='world size for distributed training')
     parser.add_argument('--rank', type=int, help='rank for the default process group')
-    parser.add_argument('--local_rank',
-                        type=int,
-                        help='local rank on the node')
-    parser.add_argument('--backend',
-                        type=str,
-                        default='nccl',
-                        help='backend for distributed communication')
+    parser.add_argument('--local_rank', type=int, help='local rank on the node')
+    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
     return parser
 
 
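Note: the multi-line add_argument calls above are collapsed into one-liners with no behavioural change. A minimal usage sketch of the resulting parser (the script invocation and flag values are hypothetical; the import path assumes this file is colossalai/initialize.py):

    from colossalai.initialize import get_default_parser

    parser = get_default_parser()
    args = parser.parse_args()
    # e.g. python train.py --config ./config.py --host 127.0.0.1 --port 29500 \
    #        --world_size 1 --rank 0 --local_rank 0 --backend nccl
    print(args.config, args.host, args.port, args.backend)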
@@ -116,9 +107,11 @@ def launch(config: Union[str, Path, Config, Dict],
 
     if verbose:
         logger = get_dist_logger()
-        logger.info(f'Distributed environment is initialized, '
-                    f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
-                    f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
+        logger.info(
+            f'Distributed environment is initialized, '
+            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
+            f'tensor parallel size: {gpc.tensor_parallel_size}',
+            ranks=[0])
 
 
 def launch_from_slurm(config: Union[str, Path, Config, Dict],
@@ -261,9 +254,11 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
 
     # print config
     if verbose:
-        logger.info(f"\n========== Your Config ========\n"
-                    f"{pprint.pformat(gpc.config)}\n"
-                    f"================================\n", ranks=[0])
+        logger.info(
+            f"\n========== Your Config ========\n"
+            f"{pprint.pformat(gpc.config)}\n"
+            f"================================\n",
+            ranks=[0])
 
     # cudnn
     cudnn_benchmark = config.get('cudnn_benchmark', True)
@@ -271,8 +266,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     torch.backends.cudnn.benchmark = cudnn_benchmark
     torch.backends.cudnn.deterministic = cudnn_deterministic
     if verbose:
-        logger.info(
-            f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
+        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
 
     # first sync model across dp ranks
     model.to(get_current_device())
@@ -321,11 +315,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     if zero_cfg is not None:
         cfg_ = zero_cfg.copy()
         level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model,
-                                           optimizer=optimizer,
-                                           level=level,
-                                           zero_config=cfg_
-                                           )
+        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
 
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
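Note: the multi-line convert_to_zero call is collapsed into a one-liner; behaviour is unchanged. For context, zero_cfg is presumably read from the user's config file earlier in initialize; a minimal sketch of such a config entry (only the level key is confirmed by this diff — the key name zero and the value 2 are illustrative assumptions):

    # config.py (hypothetical user config file)
    zero = dict(
        level=2,   # popped off by initialize(); must be present
        # any remaining keys are forwarded verbatim as zero_config
    )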
@@ -350,21 +340,22 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                     "added even though not specified in the configuration",
                     ranks=[0])
         elif is_using_sequence():
-            model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
+            model = DDP(model,
+                        process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                         device_ids=[torch.cuda.current_device()])
             if verbose:
-                logger.info(
-                    'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
+                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
+                            ranks=[0])
         elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
             model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
             if verbose:
-                logger.info(
-                    'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
+                logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
         elif is_using_ddp():
             gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
             if verbose:
                 logger.info(
-                    "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
+                    "Data parallel training is detected when using pipeline parallel, "
+                    "DataParallelGradientHandler is automatically "
                     "added even though not specified in the configuration",
                     ranks=[0])
     # add pipeline parallel gradient handler, if pipeline shared module is detected
@@ -383,7 +374,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     else:
         if not isinstance(gradient_handler_cfg, list):
             raise ConfigException(
-                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
+                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
+            )
+
+    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
+    # to avoid duplicated buffer synchronization
+    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
+        model.module.sync_buffer = False
 
     if gradient_handler_cfg is None:
         gradient_handlers = None
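Note: the functional change in this commit is the new block above; the rest is formatting. When the model is wrapped first by NaiveAMPModel (naive AMP) and then again by torch's DistributedDataParallel, both wrappers would synchronize buffers (e.g. BatchNorm running statistics) across data-parallel ranks: DDP broadcasts buffers from rank 0 on every forward pass when broadcast_buffers=True (its default), and the naive AMP wrapper's buffer sync would repeat that work. Disabling the inner sync_buffer flag avoids the duplicate communication. A minimal sketch of the double-wrapping scenario being guarded against (the NaiveAMPModel constructor call and the example model are assumptions, not the wrapper's exact API):

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP
    from colossalai.amp.naive_amp import NaiveAMPModel

    # assumes torch.distributed is already initialized (e.g. via colossalai.launch)
    net = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))  # has buffers to sync
    amp_model = NaiveAMPModel(net)  # hypothetical construction; buffer sync on by default
    ddp_model = DDP(amp_model, device_ids=[torch.cuda.current_device()])

    # DDP already broadcasts buffers on each forward (broadcast_buffers=True
    # by default), so the inner wrapper's own synchronization is redundant;
    # this mirrors the check added in the diff above:
    if isinstance(ddp_model, DDP) and isinstance(ddp_model.module, NaiveAMPModel):
        ddp_model.module.sync_buffer = False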