From 193af3a8b71ac375f65f0cd68a97e4be373c99e8 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Wed, 2 Mar 2022 16:47:17 +0800
Subject: [PATCH] added buffer sync to naive amp model wrapper (#291)

---
 colossalai/amp/naive_amp/naive_amp.py          | 72 ++++++++++++++--
 colossalai/initialize.py                       | 67 +++++++--------
 colossalai/zero/__init__.py                    | 14 ++--
 .../test_sharded_optim_with_sync_bn.py         | 84 +++++++++++++++++++
 4 files changed, 191 insertions(+), 46 deletions(-)
 create mode 100644 tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py

diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py
index 62a6b9ff2..c4e950f68 100644
--- a/colossalai/amp/naive_amp/naive_amp.py
+++ b/colossalai/amp/naive_amp/naive_amp.py
@@ -3,12 +3,15 @@

 import torch
 import torch.nn as nn
+import torch.distributed as dist
 from torch import Tensor
-from typing import Union, List, Any, Dict
+from typing import Any
 from torch.optim import Optimizer
-import torch.cuda.amp as torch_amp
-
+from torch.distributed import ReduceOp
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from ._fp16_optimizer import FP16Optimizer


@@ -43,16 +46,36 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):


 class NaiveAMPModel(nn.Module):
-    """A wrapper class for model to cast the model into fp16 and
+    """A wrapper class for model to cast the model into fp16 and
     automatically cast the input and output
     """

     def __init__(self,
                  model: nn.Module,
-                 output_to_fp32: bool = True):
+                 output_to_fp32: bool = True,
+                 parallel_mode: ParallelMode = ParallelMode.DATA,
+                 sync_buffer: bool = True):
         super().__init__()
         self.model = model.half()
         self._output_to_fp32 = output_to_fp32
+        self._sync_buf = sync_buffer
+
+        if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
+            self._process_group = gpc.get_group(parallel_mode)
+            self._world_size = gpc.get_world_size(parallel_mode)
+        else:
+            self._process_group = None
+            self._world_size = 1
+            self._sync_buf = False
+        self._first_eval_run = False
+
+    @property
+    def sync_buffer(self):
+        return self._sync_buf
+
+    @sync_buffer.setter
+    def sync_buffer(self, state: bool):
+        self._sync_buf = state

     def _convert_to_fp16(self, input_: Any):
         if isinstance(input_, Tensor) and input_.dtype == torch.float32:
@@ -64,7 +87,46 @@ class NaiveAMPModel(nn.Module):
             input_ = input_.float()
         return input_

+    def _reduce_module_buffer(self):
+        """
+        All-reduce the buffers (e.g. running stats of batch normalization) across
+        data parallel ranks so that all the ranks will produce consistent results
+        when given the same input
+        """
+        buf_list = []
+
+        # find valid buffers
+        for buf in self.model.buffers():
+            if buf is not None:
+                buf_list.append(buf)
+
+        # reduce buffers across data parallel ranks
+        if buf_list:
+            coalesced_buf = _flatten_dense_tensors(buf_list)
+            coalesced_buf.div_(self._world_size)
+            dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
+            unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
+            for old, new in zip(buf_list, unflattened_buf_list):
+                old.copy_(new)
+
+    def eval(self):
+        self.model.eval()
+
+        # we only sync buffer in the first eval iteration
+        # so that future eval iterations can be done without communication
+        self._first_eval_run = True
+
     def forward(self, *args, **kwargs):
+        # reduce buffers after forward will lead to error
+        # as we cannot change the variables needed for gradient computation after forward
+        # so we sync buffer before forward
+        if (self.training or self._first_eval_run) and self._sync_buf:
+            with torch.no_grad():
+                self._reduce_module_buffer()
+
+        if self._first_eval_run:
+            self._first_eval_run = False
+
         if args:
             args = [self._convert_to_fp16(arg) for arg in args]
         if kwargs:
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index d2620c466..010cee736 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -16,6 +16,7 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader

 from colossalai.amp import AMP_TYPE, convert_to_amp
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
@@ -23,8 +24,7 @@ from colossalai.engine import Engine
 from colossalai.global_variables import moe_env
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-from colossalai.utils import (accumulate_gradient, get_current_device,
-                              is_using_ddp, is_using_pp, is_using_sequence,
+from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
@@ -39,21 +39,12 @@ def get_default_parser():
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, help='path to the config file')
-    parser.add_argument('--host',
-                        type=str,
-                        help='the master address for distributed training')
-    parser.add_argument('--port',
-                        type=int,
-                        help='the master port for distributed training')
+    parser.add_argument('--host', type=str, help='the master address for distributed training')
+    parser.add_argument('--port', type=int, help='the master port for distributed training')
     parser.add_argument('--world_size', type=int, help='world size for distributed training')
     parser.add_argument('--rank', type=int, help='rank for the default process group')
-    parser.add_argument('--local_rank',
-                        type=int,
-                        help='local rank on the node')
-    parser.add_argument('--backend',
-                        type=str,
-                        default='nccl',
-                        help='backend for distributed communication')
+    parser.add_argument('--local_rank', type=int, help='local rank on the node')
+    parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
     return parser


@@ -116,9 +107,11 @@ def launch(config: Union[str, Path, Config, Dict],

     if verbose:
         logger = get_dist_logger()
-        logger.info(f'Distributed environment is initialized, '
-                    f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
-                    f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])
+        logger.info(
+            f'Distributed environment is initialized, '
+            f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
+            f'tensor parallel size: {gpc.tensor_parallel_size}',
+            ranks=[0])


 def launch_from_slurm(config: Union[str, Path, Config, Dict],
@@ -261,9 +254,11 @@ def initialize(model: Union[nn.Module, List[nn.Module]],

     # print config
     if verbose:
-        logger.info(f"\n========== Your Config ========\n"
-                    f"{pprint.pformat(gpc.config)}\n"
-                    f"================================\n", ranks=[0])
+        logger.info(
+            f"\n========== Your Config ========\n"
+            f"{pprint.pformat(gpc.config)}\n"
+            f"================================\n",
+            ranks=[0])

     # cudnn
     cudnn_benchmark = config.get('cudnn_benchmark', True)
@@ -271,8 +266,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     torch.backends.cudnn.benchmark = cudnn_benchmark
     torch.backends.cudnn.deterministic = cudnn_deterministic
     if verbose:
-        logger.info(
-            f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
+        logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

     # first sync model across dp ranks
     model.to(get_current_device())
@@ -321,11 +315,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     if zero_cfg is not None:
         cfg_ = zero_cfg.copy()
         level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model,
-                                           optimizer=optimizer,
-                                           level=level,
-                                           zero_config=cfg_
-                                           )
+        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)

     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
@@ -350,21 +340,22 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                 "added even though not specified in the configuration",
                 ranks=[0])
     elif is_using_sequence():
-        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
+        model = DDP(model,
+                    process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                     device_ids=[torch.cuda.current_device()])
         if verbose:
-            logger.info(
-                'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
+            logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
+                        ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
         model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
         if verbose:
-            logger.info(
-                'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
+            logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
     elif is_using_ddp():
         gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
         if verbose:
             logger.info(
-                "Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
+                "Data parallel training is detected when using pipeline parallel, "
+                "DataParallelGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
     # add pipeline parallel gradient handler, if pipeline shared module is detected
@@ -383,7 +374,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     else:
         if not isinstance(gradient_handler_cfg, list):
             raise ConfigException(
-                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
+                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
+            )
+
+    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
+    # to avoid duplicated buffer synchronization
+    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
+        model.module.sync_buffer = False

     if gradient_handler_cfg is None:
         gradient_handlers = None
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
index 708650cd8..95186233a 100644
--- a/colossalai/zero/__init__.py
+++ b/colossalai/zero/__init__.py
@@ -9,10 +9,7 @@ from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer


-def convert_to_zero(model: nn.Module,
-                    optimizer: Optimizer,
-                    level: int,
-                    zero_config: dict):
+def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
     """
     A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading

@@ -31,11 +28,16 @@ def convert_to_zero(model: nn.Module,
     assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
     if level in [1, 2]:
         if level == 2:
-            assert config['partition_grad'], 'ZeRO Optimizer requires partition_grad to be True'
+            if 'partition_grad' in zero_config:
+                assert zero_config['partition_grad'], \
+                    'Sharded Optimizer requires partition_grad to be True'
+            else:
+                zero_config['partition_grad'] = True
         model = NaiveAMPModel(model, output_to_fp32=True)
-        optimizer = ShardedOptimizer(model.parameters(), *zero_config)
+        optimizer = ShardedOptimizer(optimizer, **zero_config)
     else:
         model = ShardedModel(module=model, **zero_config)
     return model, optimizer

+
 __all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
new file mode 100644
index 000000000..d9b6524c8
--- /dev/null
+++ b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.context.parallel_mode import ParallelMode
+from torchvision.models import resnet50
+import torch.distributed as dist
+
+
+def run_dist(rank, world_size, port):
+    # need to configure cudnn deterministic so that
+    # randomness of convolution layers will be disabled
+    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
+                                  cudnn_deterministic=True,
+                                  cudnn_benchmark=False),
+                      rank=rank,
+                      world_size=world_size,
+                      host='localhost',
+                      port=port,
+                      backend='nccl')
+
+    model = resnet50()
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    engine, *args = colossalai.initialize(model, optimizer, criterion)
+
+    # train for dummy iterations
+    engine.train()
+    for _ in range(2):
+        data = torch.rand(4, 3, 128, 128).cuda().half()
+        label = torch.randint(0, 10, size=(4,)).cuda()
+        engine.zero_grad()
+        out = engine(data)
+        loss = engine.criterion(out, label)
+        engine.backward(loss)
+        engine.step()
+
+    # test
+    # need to make sure the batch norm stats are synchronized
+    # so that given the same input, the model will produce the same
+    # output on different ranks
+    engine.eval()
+    data = torch.rand(4, 3, 128, 128).cuda().half()
+    dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))
+
+    # predict
+    out = engine(data)
+
+    # test if results are equal
+    tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
+    tensor_list.insert(rank, out)
+    dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))
+
+    assert torch.all(tensor_list[0] == tensor_list[1]), \
+        'expected the output from different ranks to be the same, but got different values'
+
+
+@pytest.mark.dist
+def test_sharded_optim_with_sync_bn():
+    """
+    This test is to make sure that buffers are synchronized between ranks
+    when using ZeRO. An example of module buffer is the running stats of
+    BatchNormalization layer, i.e. mean and var.
+
+    If the buffers are not synchronized, the model will produce different
+    output even though the input and parameters are the same. This is not
+    wanted if we are doing predictions.
+
+    """
+    world_size = 2
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_sharded_optim_with_sync_bn()
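
Usage note (illustrative sketch, not taken verbatim from the diff above): assuming colossalai.launch has already initialized the data parallel group, the options added to NaiveAMPModel in this patch are intended to be used roughly as follows; the small Sequential model here is an arbitrary placeholder.

    import torch.nn as nn

    from colossalai.amp.naive_amp import NaiveAMPModel
    from colossalai.context import ParallelMode

    # an arbitrary model that owns buffers (BatchNorm running mean/var)
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    # cast parameters to fp16; buffers are all-reduced over the DATA group
    # before each training forward pass and on the first eval forward pass
    fp16_model = NaiveAMPModel(model,
                               output_to_fp32=True,
                               parallel_mode=ParallelMode.DATA,
                               sync_buffer=True)

    # buffer sync can be toggled off later, e.g. colossalai.initialize does this
    # when the wrapper is nested inside torch DDP, which already broadcasts buffers
    fp16_model.sync_buffer = False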