From f7e276fa717f6a414f56246a54b87ba5a4c36fb3 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Wed, 16 Nov 2022 14:44:28 +0800 Subject: [PATCH] [Gemini] add GeminiAdamOptimizer (#1960) --- colossalai/nn/optimizer/gemini_optimizer.py | 15 ++++++++++++ colossalai/nn/optimizer/hybrid_adam.py | 15 +++++++----- .../{zero => nn/optimizer}/zero_optimizer.py | 5 ++-- colossalai/nn/parallel/gemini_parallel.py | 3 ++- colossalai/zero/__init__.py | 2 +- examples/language/gpt/README.md | 6 ++--- examples/language/gpt/train_gpt_demo.py | 10 ++++---- examples/language/opt/run_clm.py | 2 +- examples/tutorial/opt/opt/run_clm.py | 24 +++++++++---------- tests/test_gemini/update/test_optim.py | 2 +- .../update/test_zerooptim_state_dict.py | 2 +- tests/test_tensor/test_tp_with_zero.py | 24 ++++++++++--------- 12 files changed, 66 insertions(+), 44 deletions(-) create mode 100644 colossalai/nn/optimizer/gemini_optimizer.py rename colossalai/{zero => nn/optimizer}/zero_optimizer.py (98%) diff --git a/colossalai/nn/optimizer/gemini_optimizer.py b/colossalai/nn/optimizer/gemini_optimizer.py new file mode 100644 index 000000000..31d161612 --- /dev/null +++ b/colossalai/nn/optimizer/gemini_optimizer.py @@ -0,0 +1,15 @@ +from typing import Any + +import torch + +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer + +__all__ = ['GeminiAdamOptimizer'] + + +class GeminiAdamOptimizer(ZeroOptimizer): + + def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: + optimizer = HybridAdam(model.parameters(), **defaults) + super().__init__(optimizer, model, **defaults) diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 761843aab..069b52af5 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -1,8 +1,10 @@ +from typing import Any, Optional + import torch -from colossalai.utils import multi_tensor_applier from colossalai.registry import OPTIMIZERS -from typing import Optional +from colossalai.utils import multi_tensor_applier + from .nvme_optimizer import NVMeOptimizer @@ -11,7 +13,7 @@ class HybridAdam(NVMeOptimizer): """Implements Adam algorithm. Supports parameters updating on both GPU and CPU, depanding on the device of paramters. - But the parameters and gradients should on the same device: + But the parameters and gradients should on the same device: * Parameters on CPU and gradients on CPU is allowed. * Parameters on GPU and gradients on GPU is allowed. * Parameters on GPU and gradients on CPU is **not** allowed. @@ -43,7 +45,7 @@ class HybridAdam(NVMeOptimizer): (default: False) NOT SUPPORTED yet in CPUAdam! adamw_mode (boolean, optional): Apply L2 regularization or weight decay True for decoupled weight decay(also known as AdamW) (default: True) - simd_log (boolean, optional): whether to show if you are using SIMD to + simd_log (boolean, optional): whether to show if you are using SIMD to accelerate. (default: False) nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0. nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files. 
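The new `GeminiAdamOptimizer` added above is a thin convenience wrapper: it builds a `HybridAdam` over `model.parameters()` and hands it, together with the keyword arguments, to `ZeroOptimizer`. A minimal usage sketch (assuming `model` has already been wrapped by `ZeroDDP`/`GeminiDDP`, which the `ZeroOptimizer` constructor asserts):

```python
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

# One-liner introduced by this patch. All kwargs are forwarded to both wrapped
# classes: lr is consumed by HybridAdam, initial_scale by ZeroOptimizer, and the
# new **defaults parameters let each ignore the kwargs meant for the other.
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)

# Equivalent two-step construction that the examples further down replace:
# from colossalai.nn.optimizer import HybridAdam
# from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
# optimizer = HybridAdam(model.parameters(), lr=1e-3)
# optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
```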
@@ -68,14 +70,15 @@ class HybridAdam(NVMeOptimizer):
                  weight_decay=0,
                  adamw_mode=True,
                  nvme_offload_fraction: float = 0.0,
-                 nvme_offload_dir: Optional[str] = None):
+                 nvme_offload_dir: Optional[str] = None,
+                 **defaults: Any):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
         try:
-            import cpu_adam
             import colossal_C
+            import cpu_adam
         except ImportError:
             raise ImportError('Please install colossalai from source code to use HybridAdam')
 
diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py
similarity index 98%
rename from colossalai/zero/zero_optimizer.py
rename to colossalai/nn/optimizer/zero_optimizer.py
index 9a3101e38..09ecbb2c7 100644
--- a/colossalai/zero/zero_optimizer.py
+++ b/colossalai/nn/optimizer/zero_optimizer.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Set, Tuple
+from typing import Any, Dict, Set, Tuple
 
 import torch
 import torch.distributed as dist
@@ -55,7 +55,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  backoff_factor: float = 0.5,
                  growth_interval: int = 1000,
                  hysteresis: int = 2,
-                 max_scale: float = 2**32):
+                 max_scale: float = 2**32,
+                 **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
         self.module = module
diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py
index c1223c27f..6cc188b4b 100644
--- a/colossalai/nn/parallel/gemini_parallel.py
+++ b/colossalai/nn/parallel/gemini_parallel.py
@@ -16,8 +16,9 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  search_range_mb: int = 32) -> None:
         """
-        A torch.Module warpper using ZeRODPP and Genimi.
+        A torch.nn.Module wrapper using ZeRO-DP and Gemini.
         ZeRO is for parallel. Gemini is for memory management.
+        WARNING: The class will modify the module in place!
 
         Example:
             model is initialized under the context of ColoInitContext
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
index 3a896322f..098ccbb45 100644
--- a/colossalai/zero/__init__.py
+++ b/colossalai/zero/__init__.py
@@ -7,7 +7,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
 
-from .zero_optimizer import ZeroOptimizer
+from ..nn.optimizer.zero_optimizer import ZeroOptimizer
 
 
 def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md
index e0e1dc5c1..2fc401004 100644
--- a/examples/language/gpt/README.md
+++ b/examples/language/gpt/README.md
@@ -3,10 +3,8 @@ This example shows how to use Colossal-AI to run huggingface GPT training in dis
 
 ## GPT
 We use the GPT2 model from huggingface transformers. The input data is randonly generated.
-
-## Our Modifications
-The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO.
-The Colossal-AI leverages Tensor Parallel and Gemini.
+The `train_gpt_demo.py` provides three distributed plans: ColossalAI, PyTorch DDP and ZeRO.
+The ColossalAI plan leverages Tensor Parallelism and Gemini.
 
 ## Quick Start
 You can launch training by using the following bash script.
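For context, `GeminiDDP` (touched above) is the one-stop wrapper around chunk management and the Gemini memory manager, and, per the new warning, it modifies the wrapped module in place. A rough sketch of the positional call exercised by the test updated later in this patch; the concrete placement-policy string and the use of `get_current_device()` are illustrative assumptions:

```python
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device

# `model` is assumed to have been created under ColoInitContext (see the docstring above).
# Argument order follows tests/test_tensor/test_tp_with_zero.py in this patch:
# (module, device, placement_policy, pin_memory, force_outputs_fp32, search_range_mb)
model = GeminiDDP(model,
                  get_current_device(),    # device on which chunks are initialised
                  'cuda',                  # placement policy (illustrative value)
                  True,                    # pin_memory
                  False,                   # force_outputs_fp32
                  32)                      # search_range_mb for the chunk-size search
```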
diff --git a/examples/language/gpt/train_gpt_demo.py b/examples/language/gpt/train_gpt_demo.py index 99de40e5f..92123e6a7 100644 --- a/examples/language/gpt/train_gpt_demo.py +++ b/examples/language/gpt/train_gpt_demo.py @@ -10,11 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer from transformers import GPT2Config, GPT2LMHeadModel @@ -222,7 +223,7 @@ def main(): default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None # build GPT model - with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg): + with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): model = gpt2_medium(checkpoint=True) pg = default_pg @@ -232,8 +233,9 @@ def main(): model = gemini_zero_dpp(model, pg, args.placement) # build optimizer - optimizer = HybridAdam(model.parameters(), lr=1e-3) - optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5) + optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) + # optimizer = HybridAdam(model.parameters(), lr=1e-3) + # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5) logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) elif args.distplan == "ddp": diff --git a/examples/language/opt/run_clm.py b/examples/language/opt/run_clm.py index 00e05459a..c6590323e 100755 --- a/examples/language/opt/run_clm.py +++ b/examples/language/opt/run_clm.py @@ -43,11 +43,11 @@ from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ProcessGroup from colossalai.utils import get_current_device, get_dataloader from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer from transformers import ( CONFIG_MAPPING, MODEL_MAPPING, diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index 00a2da101..c4f576cb1 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -30,13 +30,24 @@ from itertools import chain import datasets import torch import torch.distributed as dist -import transformers from accelerate.utils import set_seed from context import barrier_context from datasets import load_dataset from packaging import version from torch.utils.data import DataLoader from tqdm.auto import tqdm + +import colossalai +import transformers +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel 
import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device, get_dataloader +from colossalai.utils.model.colo_init_context import ColoInitContext from transformers import ( CONFIG_MAPPING, MODEL_MAPPING, @@ -50,17 +61,6 @@ from transformers import ( ) from transformers.utils.versions import require_version -import colossalai -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import ZeroDDP -from colossalai.tensor import ProcessGroup -from colossalai.utils import get_current_device, get_dataloader -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer - require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index a7c2fc2b2..008813698 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -12,12 +12,12 @@ from colossalai.amp import convert_to_apex_amp from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.nn.parallel import ZeroDDP from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py index 74761668a..68885e543 100644 --- a/tests/test_gemini/update/test_zerooptim_state_dict.py +++ b/tests/test_gemini/update/test_zerooptim_state_dict.py @@ -9,12 +9,12 @@ import colossalai from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.nn.parallel import ZeroDDP from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index 9ea274fd1..b87802191 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -7,16 +7,14 @@ from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from 
colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import ZeroDDP +from colossalai.gemini.chunk import search_chunk_configuration +from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer +from colossalai.nn.parallel import GeminiDDP, ZeroDDP from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import ZeroOptimizer from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed, tensor_shard_equal from tests.test_tensor.model.test_gpt2 import init_megatron_spec @@ -96,19 +94,23 @@ def run_gpt(placement_policy, tp_init_spec_func=None): init_device = torch.device('cpu') else: init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) + model = GeminiDDP(model, init_device, placement_policy, True, False, 32) + # The same as the following 3 lines + # chunk_manager = ChunkManager(config_dict, init_device=init_device) + # gemini_manager = GeminiManager(placement_policy, chunk_manager) + # model = ZeroDDP(model, gemini_manager, pin_memory=True) + + zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1) + # The same as the following 2 lines + # optimizer = HybridAdam(model.parameters(), lr=1e-3) + # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1) amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) - print(chunk_manager) check_param(model, torch_model, pg) model.eval()
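Taken together, the patch converges on a `GeminiDDP` + `GeminiAdamOptimizer` pattern. The sketch below strings the pieces into a toy training step; the launcher call, toy model, fp16 input, and the `optimizer.backward(loss)`/`optimizer.step()` sequence follow the existing GPT demo and are assumptions rather than part of this diff (run it under `torchrun` or the colossalai launcher):

```python
import torch
import torch.nn.functional as F

import colossalai
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

colossalai.launch_from_torch(config={})

# Build the model under ColoInitContext so its parameters are ColoParameters.
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

# Wrap with GeminiDDP (modifies `model` in place) and build the fused optimizer.
model = GeminiDDP(model, get_current_device(), 'cuda', True, False, 32)
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)

# One toy step; the input is fp16 because ZeroDDP keeps parameters in fp16.
data = torch.randn(4, 32, dtype=torch.half, device=get_current_device())
label = torch.randint(0, 2, (4,), device=get_current_device())

optimizer.zero_grad()
loss = F.cross_entropy(model(data), label)
optimizer.backward(loss)    # assumed ZeroOptimizer API, mirroring the GPT demo's training loop
optimizer.step()
```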