From 055fbf5be680dfde20be1c51302f3c8b154a93e4 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Fri, 1 Apr 2022 20:10:47 +0800
Subject: [PATCH] [zero] adapt zero for unsharded parameters (Optimizer part) (#601)

---
 colossalai/utils/checkpointing.py           |   5 +-
 colossalai/zero/init_ctx/init_context.py    |  25 +++-
 .../zero/sharded_optim/sharded_optim_v2.py  |  52 ++++---
 tests/components_to_test/no_leaf_module.py  |   3 +-
 tests/test_moe/test_moe_zero_init.py        |   7 +-
 tests/test_moe/test_moe_zero_model.py       |   2 +-
 tests/test_moe/test_moe_zero_optim.py       | 134 ++++++++++++++++++
 tests/test_zero_data_parallel/common.py     |  24 ++--
 8 files changed, 208 insertions(+), 44 deletions(-)
 create mode 100644 tests/test_moe/test_moe_zero_optim.py

diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py
index bc49656bc..34eaa2ea0 100644
--- a/colossalai/utils/checkpointing.py
+++ b/colossalai/utils/checkpointing.py
@@ -6,7 +6,10 @@ import torch.distributed as dist
 from colossalai.communication.collective import scatter_object_list
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+try:
+    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+except ImportError:
+    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 
 from .common import is_using_pp
 
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 4be143c15..52a166d89 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -11,6 +11,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
+from contextlib import AbstractContextManager
 
 
 def _substitute_init_recursively(cls, func):
@@ -88,6 +89,7 @@ class ZeroContextConfig(object):
     """The configuration used to control zero context initialization.
 
     Args:
+        target_device (torch.device): The device where param data are after exiting the context.
         replicated (bool, optional): Whether the param is replicated across data parallel group.
             Some parameters are not replicated, e.g. parameters in MOE experts.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
@@ -99,8 +101,13 @@
             See torchvision resnet18. Defaults to False.
     """
 
-    def __init__(self, replicated: bool = True, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
+    def __init__(self,
+                 target_device: torch.device,
+                 replicated: bool = True,
+                 shard_param: bool = False,
+                 rm_torch_payload_on_the_fly: bool = False):
         super().__init__()
+        self.target_device = target_device
         self.is_replicated: bool = replicated
         self.shard_param: bool = shard_param
         self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
@@ -114,7 +121,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.
 
     Args:
-        target_device (torch.device): The device where param data after exiting the context.
+        target_device (torch.device): The device where param data are after exiting the context.
         shard_strategy (BaseShardStrategy): Shard strategy instance.
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
         rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
@@ -136,17 +143,22 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
-        self.target_device = target_device
         self.shard_strategy = shard_strategy
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
-        self.config = ZeroContextConfig(replicated=True,
+        self.config = ZeroContextConfig(target_device=target_device,
+                                        replicated=True,
                                         shard_param=shard_param,
                                         rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
 
+        ZeroContextMgr().current_context = self
+
+    @property
+    def target_device(self):
+        return self.config.target_device
+
     @property
     def is_replicated(self):
         return self.config.is_replicated
@@ -235,8 +247,9 @@ class ZeroContextMgr(metaclass=SingletonMeta):
         self.current_context.config = old_config
 
 
-def no_shard_zero_context(is_replicated: bool = True):
-    return ZeroContextMgr().hijack_context_config(replicated=is_replicated,
+def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
+    return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
+                                                  replicated=is_replicated,
                                                   shard_param=False,
                                                   rm_torch_payload_on_the_fly=False)
 
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index a45677b7b..dfb0f00d8 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -12,13 +12,12 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.zero.shard_utils.tensor_utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
+from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
+                                                      colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
-from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
-
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -69,6 +68,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
         growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
         hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
+        keep_unsharded (bool, optional): if True, optimizer won't shard unsharded parameters.
+            In Zero-2, set keep_unsharded to False.
+            In Zero-3, set keep_unsharded to True.
         max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
         dp_process_group (Optional[ProcessGroup], optional): data paralle process group. Defaults to None.
         mp_process_group (Optional[ProcessGroup], optional): model paralle process group.
@@ -89,6 +91,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  growth_interval: float = 1000,
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
+                 keep_unsharded: bool = False,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@@ -122,24 +125,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
         self._logger = get_dist_logger("ShardedOptimizerV2")
 
-        # Store fp32 param shards
-        self.master_params: Dict[Parameter, StatefulTensor] = {}
+        assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \
+            "Keeping unsharded parameters can't be used with hybrid OS placement right now."
+        self.keep_unshard = keep_unsharded
 
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
-                    # TODO (ver217): we may not use shard / gather here
-                    # Param is no sharded, which means we use ZeRO-2 here
-                    # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = StatefulTensor(
-                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
-                if not is_param_sharded:
-                    # In this branch, there's no need to shard param
-                    # So we gather here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+        # Store fp32 param shards
+        self._register_master_weight()
 
         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                            ranks=[0])
@@ -283,6 +274,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def sync_grad(self):
         pass
 
+    def _register_master_weight(self):
+        self.master_params: Dict[Parameter, StatefulTensor] = {}
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
+                if not is_param_sharded and not self.keep_unshard:
+                    # Please use keep_unsharded to control whether to shard unsharded parameters
+                    # As we only store param shard, we shard it here
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = StatefulTensor(
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
+                if not is_param_sharded and not self.keep_unshard:
+                    # In this branch, there's no need to shard param
+                    # So we gather here
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+
     def _maybe_move_fp32_shards(self):
         if self._should_move_fp32_shards_h2d:
             self._should_move_fp32_shards_h2d = False
@@ -328,7 +336,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
+                if not is_param_sharded and not self.keep_unshard:
                     # We use ZeRO-2 here
                     # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
@@ -342,7 +350,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 p.colo_attr.sharded_data_tensor.reset_payload(
                     colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
 
-                if not is_param_sharded:
+                if not is_param_sharded and not self.keep_unshard:
                     # We gather full fp16 param here
                     self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 p.data = p.colo_attr.sharded_data_tensor.payload
diff --git a/tests/components_to_test/no_leaf_module.py b/tests/components_to_test/no_leaf_module.py
index c944ff48f..28a212f96 100644
--- a/tests/components_to_test/no_leaf_module.py
+++ b/tests/components_to_test/no_leaf_module.py
@@ -42,4 +42,5 @@ def get_training_components():
     testloader = DummyDataLoader()
     criterion = torch.nn.CrossEntropyLoss()
 
-    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
+    from colossalai.nn.optimizer import HybridAdam
+    return model_builder, trainloader, testloader, HybridAdam, criterion
diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
index 0fa5067a5..b1baab8eb 100644
--- a/tests/test_moe/test_moe_zero_init.py
+++ b/tests/test_moe/test_moe_zero_init.py
@@ -76,8 +76,11 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
         else:
             assert param.is_replicated
 
-        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        if param.colo_attr.param_is_sharded:
+            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+                f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        else:
+            assert param.colo_attr.sharded_data_tensor.payload.device.type == 'cuda'
 
 
 def _run_dist(rank, world_size, port):
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
index 5b5a73d05..87a72a8e1 100644
--- a/tests/test_moe/test_moe_zero_model.py
+++ b/tests/test_moe/test_moe_zero_model.py
@@ -67,7 +67,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("world_size", [2])
 @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_moe_zero_model(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
new file mode 100644
index 000000000..91004545f
--- /dev/null
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -0,0 +1,134 @@
+from functools import partial
+
+import colossalai
+from colossalai.utils.cuda import get_current_device
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.amp import convert_to_apex_amp
+from colossalai.nn.optimizer import CPUAdam
+from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.utils import free_port
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from colossalai.zero.sharded_optim import ShardedOptimizerV2
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.utils import get_current_device
+from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.engine.gradient_handler import MoeGradientHandler
+from colossalai.context import MOE_CONTEXT
+from colossalai.testing import assert_equal_in_group
+
+from tests.test_zero_data_parallel.common import CONFIG, check_sharded_model_params
+from tests.test_moe.test_moe_zero_init import MoeModel
+
+
+def _run_step(model, optimizer, data, label, criterion, grad_handler):
+    model.train()
+    optimizer.zero_grad()
+
+    if criterion:
+        y = model(data)
+        loss = criterion(y, label)
+    else:
+        loss = model(data, label)
+
+    loss = loss.float()
+    if isinstance(model, ShardedModelV2):
+        optimizer.backward(loss)
+    else:
+        loss.backward()
+
+    if grad_handler is not None:
+        grad_handler.handle_gradient()
+
+    optimizer.step()
+
+
+@parameterize("cpu_offload", [True, False])
+@parameterize("use_cpuadam", [True, False])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
+    MOE_CONTEXT.reset_loss()
+    shard_strategy = shard_strategy_class()
+    if use_cpuadam and cpu_offload is False:
+        return
+
+    get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
+    _, train_dataloader, _, optimizer_class, criterion = get_components_func()
+
+    with ZeroInitContext(
+            target_device=torch.device('cpu') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
+            shard_strategy=shard_strategy,
+            shard_param=True,
+            rm_torch_payload_on_the_fly=False):
+        zero_model = MoeModel()
+
+    zero_model = ShardedModelV2(
+        zero_model,
+        shard_strategy,
+        offload_config=dict(device='cpu') if cpu_offload else None,
+        use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+        reuse_fp16_shard=use_cpuadam,
+    )
+
+    # check whether parameters are identical in ddp
+    for name, p in zero_model.named_parameters():
+        if not p.colo_attr.param_is_sharded and p.is_replicated:
+            assert_equal_in_group(p.data.to(get_current_device()))
+
+    model = MoeModel().half()
+    col_model_deepcopy(zero_model, model)
+    model = model.cuda().float()
+
+    if use_cpuadam:
+        optimizer_class = CPUAdam
+    optim = optimizer_class(model.parameters(), lr=1e-3)
+    sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
+    sharded_optim = ShardedOptimizerV2(zero_model,
+                                       sharded_optim,
+                                       cpu_offload=cpu_offload,
+                                       initial_scale=2**5,
+                                       gpu_margin_mem_ratio=gpu_margin_mem_ratio,
+                                       keep_unsharded=True)
+
+    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
+    apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
+    apex_grad_handler = MoeGradientHandler(model)
+
+    # Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
+    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
+        if 'gate' in n:
+            p.data = p.float()
+            p.data.copy_(zp.data)
+
+    for i, (data, label) in enumerate(train_dataloader):
+        if i > 5:
+            break
+        data, label = data.cuda(), label.cuda()
+        _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler)
+        _run_step(zero_model, sharded_optim, data, label, criterion, None)
+        check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
+        for param in model.parameters():
+            assert not has_inf_or_nan(param)
+
+
+def _run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    MOE_CONTEXT.setup(seed=42)
+    _run_test_sharded_optim_v2()
+
+
+# use_cpuadam = True cannot be used with cpu_offload = False
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_moe_zero_optim(world_size):
+    run_func = partial(_run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_moe_zero_optim(world_size=2)
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index c7948c0fe..0143c0e3c 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -124,16 +124,18 @@ def check_params_padding(model, zero_model, loose=False):
 
 
 def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
     rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        if reuse_fp16_shard:
-            zero_p = zero_p.data.to(p.device).float()
-        else:
-            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
-        chunks = torch.flatten(p).chunk(dist.get_world_size())
-        if rank >= len(chunks):
-            continue
-        p = chunks[rank].float()
-        if zero_p.size(0) > p.size(0):
-            zero_p = zero_p[:p.size(0)]
+    for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
+        if zero_p.colo_attr.param_is_sharded:
+            if reuse_fp16_shard:
+                zero_p = zero_p.data.to(p.device).float()
+            else:
+                zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
+            chunks = torch.flatten(p).chunk(dist.get_world_size())
+            if rank >= len(chunks):
+                continue
+            p = chunks[rank].float()
+            if zero_p.size(0) > p.size(0):
+                zero_p = zero_p[:p.size(0)]
+        assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
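
Usage note (not part of the commit above): the sketch below is a minimal illustration distilled from the new test in this patch (tests/test_moe/test_moe_zero_optim.py), showing how the keep_unsharded flag introduced here is meant to be combined with ZeroInitContext and ShardedModelV2. MyModel, the shard strategy choice, and the learning rate are placeholders rather than values required by the patch, and colossalai.launch(...) is assumed to have initialized the distributed environment beforehand.

    import torch
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy
    from colossalai.zero.sharded_model import ShardedModelV2
    from colossalai.zero.sharded_optim import ShardedOptimizerV2

    shard_strategy = TensorShardStrategy()

    # Build the model inside the context so its parameters are sharded on creation;
    # parameters built under no_shard_zero_context (e.g. MoE experts) stay unsharded.
    with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                         shard_strategy=shard_strategy,
                         shard_param=True,
                         rm_torch_payload_on_the_fly=False):
        model = MyModel()    # placeholder for any torch.nn.Module

    model = ShardedModelV2(model, shard_strategy)

    # keep_unsharded=True tells the optimizer not to shard the parameters that are
    # still unsharded after init (ZeRO-3 style); with the default keep_unsharded=False
    # (ZeRO-2 style) the optimizer temporarily shards them so it only stores an fp32
    # master shard, as done in _register_master_weight above.
    optimizer = ShardedOptimizerV2(model,
                                   HybridAdam(model.parameters(), lr=1e-3),
                                   keep_unsharded=True)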