hpcaitech/ColossalAI (mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-28 16:28:10 +00:00)

commit 193af3a8b7 (parent 6f22fb1906)
added buffer sync to naive amp model wrapper (#291)
@@ -3,12 +3,15 @@
 import torch
 import torch.nn as nn
+import torch.distributed as dist
 from torch import Tensor
-from typing import Union, List, Any, Dict
+from typing import Any
 from torch.optim import Optimizer
-import torch.cuda.amp as torch_amp
+from torch.distributed import ReduceOp
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from ._fp16_optimizer import FP16Optimizer
 
 
@@ -49,10 +52,30 @@ class NaiveAMPModel(nn.Module):
     def __init__(self,
                  model: nn.Module,
-                 output_to_fp32: bool = True):
+                 output_to_fp32: bool = True,
+                 parallel_mode: ParallelMode = ParallelMode.DATA,
+                 sync_buffer: bool = True):
         super().__init__()
         self.model = model.half()
         self._output_to_fp32 = output_to_fp32
+        self._sync_buf = sync_buffer
+
+        if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
+            self._process_group = gpc.get_group(parallel_mode)
+            self._world_size = gpc.get_world_size(parallel_mode)
+        else:
+            self._process_group = None
+            self._world_size = 1
+            self._sync_buf = False
+        self._first_eval_run = False
+
+    @property
+    def sync_buffer(self):
+        return self._sync_buf
+
+    @sync_buffer.setter
+    def sync_buffer(self, state: bool):
+        self._sync_buf = state
+
     def _convert_to_fp16(self, input_: Any):
         if isinstance(input_, Tensor) and input_.dtype == torch.float32:
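For context, the snippet below is a hedged usage sketch, not part of the commit: it shows how a caller might construct the wrapper with the new arguments. The model, the .cuda() call, and the values are illustrative assumptions.

# Illustrative sketch only -- the wrapped model and settings are assumptions, not from the commit.
import torch.nn as nn

from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context import ParallelMode

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).cuda()
fp16_model = NaiveAMPModel(model,
                           output_to_fp32=True,
                           parallel_mode=ParallelMode.DATA,  # group whose ranks should share buffer values
                           sync_buffer=True)                 # sync e.g. BatchNorm running stats before forward

# the property setter added above allows toggling the behaviour later,
# e.g. when another wrapper already keeps buffers consistent
fp16_model.sync_buffer = False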
@@ -64,7 +87,46 @@ class NaiveAMPModel(nn.Module):
             input_ = input_.float()
         return input_
 
+    def _reduce_module_buffer(self):
+        """
+        All-reduce the buffers (e.g. running stats of batch normalization) across
+        data parallel ranks so that all the ranks will produce consistent results
+        when given the same input
+        """
+        buf_list = []
+
+        # find valid buffers
+        for buf in self.model.buffers():
+            if buf is not None:
+                buf_list.append(buf)
+
+        # reduce buffers across data parallel ranks
+        if buf_list:
+            coalesced_buf = _flatten_dense_tensors(buf_list)
+            coalesced_buf.div_(self._world_size)
+            dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
+            unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
+            for old, new in zip(buf_list, unflattened_buf_list):
+                old.copy_(new)
+
+    def eval(self):
+        self.model.eval()
+
+        # we only sync buffer in the first eval iteration
+        # so that future eval iterations can be done without communication
+        self._first_eval_run = True
+
     def forward(self, *args, **kwargs):
+        # reduce buffers after forward will lead to error
+        # as we cannot change the variables needed for gradient computation after forward
+        # so we sync buffer before forward
+        if (self.training or self._first_eval_run) and self._sync_buf:
+            with torch.no_grad():
+                self._reduce_module_buffer()
+
+            if self._first_eval_run:
+                self._first_eval_run = False
+
         if args:
             args = [self._convert_to_fp16(arg) for arg in args]
         if kwargs:
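The _reduce_module_buffer method above coalesces all module buffers into a single flat tensor, averages it over the data-parallel ranks (divide, then SUM all-reduce), and copies the result back in place. Below is a minimal standalone sketch of the same pattern with plain torch.distributed; it assumes the process group is already initialized and that all buffers share a floating-point dtype.

# Standalone sketch of the coalesced buffer-averaging pattern used by _reduce_module_buffer.
# Assumes torch.distributed is initialized and all buffers have the same floating-point dtype.
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def average_module_buffers(module: torch.nn.Module, group=None):
    bufs = [buf for buf in module.buffers() if buf is not None]
    if not bufs:
        return
    flat = _flatten_dense_tensors(bufs)                  # one contiguous tensor -> a single all-reduce call
    flat.div_(dist.get_world_size(group))                # pre-divide so that SUM across ranks yields the mean
    dist.all_reduce(flat, op=ReduceOp.SUM, group=group)
    for buf, synced in zip(bufs, _unflatten_dense_tensors(flat, bufs)):
        buf.copy_(synced)                                # write the averaged values back in place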
@@ -16,6 +16,7 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
 from colossalai.amp import AMP_TYPE, convert_to_amp
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
@@ -23,8 +24,7 @@ from colossalai.engine import Engine
 from colossalai.global_variables import moe_env
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-from colossalai.utils import (accumulate_gradient, get_current_device,
-                              is_using_ddp, is_using_pp, is_using_sequence,
-                              sync_model_param)
+from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
+                              sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
@@ -39,21 +39,12 @@ def get_default_parser():
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, help='path to the config file')
-    parser.add_argument('--host',
-                        type=str,
-                        help='the master address for distributed training')
-    parser.add_argument('--port',
-                        type=int,
-                        help='the master port for distributed training')
+    parser.add_argument('--host', type=str, help='the master address for distributed training')
+    parser.add_argument('--port', type=int, help='the master port for distributed training')
     parser.add_argument('--world_size', type=int, help='world size for distributed training')
     parser.add_argument('--rank', type=int, help='rank for the default process group')
-    parser.add_argument('--local_rank',
-                        type=int,
-                        help='local rank on the node')
-    parser.add_argument('--backend',
-                        type=str,
-                        default='nccl',
-                        help='backend for distributed communication')
+    parser.add_argument('--local_rank', type=int, help='local rank on the node')
+    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
     return parser
 
 
@@ -116,9 +107,11 @@ def launch(config: Union[str, Path, Config, Dict],
 
     if verbose:
         logger = get_dist_logger()
-        logger.info(f'Distributed environment is initialized, '
-                    f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
-                    f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
+        logger.info(
+            f'Distributed environment is initialized, '
+            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
+            f'tensor parallel size: {gpc.tensor_parallel_size}',
+            ranks=[0])
 
 
 def launch_from_slurm(config: Union[str, Path, Config, Dict],
@@ -261,9 +254,11 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
 
     # print config
     if verbose:
-        logger.info(f"\n========== Your Config ========\n"
-                    f"{pprint.pformat(gpc.config)}\n"
-                    f"================================\n", ranks=[0])
+        logger.info(
+            f"\n========== Your Config ========\n"
+            f"{pprint.pformat(gpc.config)}\n"
+            f"================================\n",
+            ranks=[0])
 
     # cudnn
     cudnn_benchmark = config.get('cudnn_benchmark', True)
@@ -271,8 +266,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     torch.backends.cudnn.benchmark = cudnn_benchmark
     torch.backends.cudnn.deterministic = cudnn_deterministic
     if verbose:
-        logger.info(
-            f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
+        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
 
     # first sync model across dp ranks
     model.to(get_current_device())
@@ -321,11 +315,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     if zero_cfg is not None:
        cfg_ = zero_cfg.copy()
        level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model,
-                                           optimizer=optimizer,
-                                           level=level,
-                                           zero_config=cfg_
-                                           )
+        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
 
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
@@ -350,21 +340,22 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                 "added even though not specified in the configuration",
                 ranks=[0])
     elif is_using_sequence():
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
+        model = DDP(model,
+                    process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                     device_ids=[torch.cuda.current_device()])
         if verbose:
-            logger.info(
-                'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
+            logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
+                        ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
         model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
         if verbose:
-            logger.info(
-                'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
+            logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
     elif is_using_ddp():
         gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
         if verbose:
             logger.info(
-                "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
+                "Data parallel training is detected when using pipeline parallel, "
+                "DataParallelGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
     # add pipeline parallel gradient handler, if pipeline shared module is detected
@@ -383,7 +374,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     else:
         if not isinstance(gradient_handler_cfg, list):
             raise ConfigException(
-                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
+                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
+            )
+
+    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
+    # to avoid duplicated buffer synchronization
+    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
+        model.module.sync_buffer = False
 
     if gradient_handler_cfg is None:
         gradient_handlers = None
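The guard above exists because torch's DistributedDataParallel already broadcasts module buffers from rank 0 before each forward pass when broadcast_buffers=True (its default), so keeping the wrapper's all-reduce based sync as well would do the same work twice. A hedged sketch of the check in isolation; the helper name is an assumption, not part of the commit:

# Hedged sketch: disable the wrapper's own buffer sync when torch DDP already keeps buffers consistent.
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.amp.naive_amp import NaiveAMPModel


def maybe_disable_duplicate_buffer_sync(model):
    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
        model.module.sync_buffer = False  # DDP's broadcast_buffers=True already syncs buffers each forward
    return model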
@@ -9,10 +9,7 @@ from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer
 
 
-def convert_to_zero(model: nn.Module,
-                    optimizer: Optimizer,
-                    level: int,
-                    zero_config: dict):
+def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
     """
     A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
 
@@ -31,11 +28,16 @@ def convert_to_zero(model: nn.Module,
     assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
     if level in [1, 2]:
         if level == 2:
-            assert config['partition_grad'], 'ZeRO Optimizer requires partition_grad to be True'
+            if 'partition_grad' in zero_config:
+                assert zero_config['partition_grad'], \
+                    'Sharded Optimizer requires partition_grad to be True'
+            else:
+                zero_config['partition_grad'] = True
         model = NaiveAMPModel(model, output_to_fp32=True)
-        optimizer = ShardedOptimizer(model.parameters(), *zero_config)
+        optimizer = ShardedOptimizer(optimizer, **zero_config)
     else:
         model = ShardedModel(module=model, **zero_config)
     return model, optimizer
 
 
 __all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
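For reference, a hedged sketch of calling the reworked convert_to_zero. It assumes colossalai.launch has already set up the distributed environment; the model, optimizer, and config values are illustrative and mirror the new test below.

# Illustrative sketch only -- assumes colossalai.launch(...) has already initialized the environment.
import torch
import torch.nn as nn

from colossalai.zero import convert_to_zero

model = nn.Linear(32, 32).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# level 2 also shards gradients, so partition_grad is expected to be True (see the assert above)
model, optimizer = convert_to_zero(model=model,
                                   optimizer=optimizer,
                                   level=2,
                                   zero_config=dict(partition_grad=True))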
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.context.parallel_mode import ParallelMode
+from torchvision.models import resnet50
+import torch.distributed as dist
+
+
+def run_dist(rank, world_size, port):
+    # need to configure cudnn deterministic so that
+    # randomness of convolution layers will be disabled
+    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
+                                  cudnn_deterministic=True,
+                                  cudnn_benchmark=False),
+                      rank=rank,
+                      world_size=world_size,
+                      host='localhost',
+                      port=port,
+                      backend='nccl')
+
+    model = resnet50()
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    engine, *args = colossalai.initialize(model, optimizer, criterion)
+
+    # train for dummy iterations
+    engine.train()
+    for _ in range(2):
+        data = torch.rand(4, 3, 128, 128).cuda().half()
+        label = torch.randint(0, 10, size=(4,)).cuda()
+        engine.zero_grad()
+        out = engine(data)
+        loss = engine.criterion(out, label)
+        engine.backward(loss)
+        engine.step()
+
+    # test
+    # need to make sure the batch norm stats are synchronized
+    # so that given the same input, the model will produce the same
+    # output on different ranks
+    engine.eval()
+    data = torch.rand(4, 3, 128, 128).cuda().half()
+    dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))
+
+    # predict
+    out = engine(data)
+
+    # test if results are equal
+    tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
+    tensor_list.insert(rank, out)
+    dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))
+
+    assert torch.all(tensor_list[0] == tensor_list[1]), \
+        'expected the output from different ranks to be the same, but got different values'
+
+
+@pytest.mark.dist
+def test_sharded_optim_with_sync_bn():
+    """
+    This test is to make sure that buffers are synchronized between ranks
+    when using ZeRO. An example of module buffer is the running stats of
+    BatchNormalization layer, i.e. mean and var.
+
+    If the buffers are not synchronized, the model will produce different
+    output even though the input and parameters are the same. This is not
+    wanted if we are doing predictions.
+    """
+    world_size = 2
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_sharded_optim_with_sync_bn()
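The test spawns two processes (world_size = 2, nccl backend), so it needs at least two GPUs; it can be launched by running the file directly or through pytest with the dist marker enabled. To see why the synchronization matters, here is a tiny single-process illustration (an assumption for explanation, not part of the commit) of how BatchNorm buffers drift apart when different ranks see different data:

# Two BatchNorm layers stand in for two ranks: they accumulate different running stats,
# so the same eval input gives different outputs until the buffers are averaged.
import torch
import torch.nn as nn

bn_rank0, bn_rank1 = nn.BatchNorm1d(4), nn.BatchNorm1d(4)
bn_rank0(torch.randn(8, 4))        # "rank 0" sees one batch distribution
bn_rank1(torch.randn(8, 4) * 5)    # "rank 1" sees a very different one

bn_rank0.eval()
bn_rank1.eval()
x = torch.randn(2, 4)
print(torch.allclose(bn_rank0(x), bn_rank1(x)))  # almost certainly False: the running stats differ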