[Gemini] add GeminiAdamOptimizer (#1960)

Authored by Jiarui Fang on 2022-11-16 14:44:28 +08:00, committed via GitHub
parent 7066dfbf82
commit f7e276fa71
12 changed files with 66 additions and 44 deletions

View File

@@ -0,0 +1,15 @@
+from typing import Any
+
+import torch
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+
+__all__ = ['GeminiAdamOptimizer']
+
+
+class GeminiAdamOptimizer(ZeroOptimizer):
+
+    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
+        optimizer = HybridAdam(model.parameters(), **defaults)
+        super().__init__(optimizer, model, **defaults)
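The new optimizer is a convenience wrapper: it builds a HybridAdam over the module's parameters and hands both to ZeroOptimizer. A minimal usage sketch, assuming a module that has already been wrapped by ZeroDDP/GeminiDDP as in the demo and test changes later in this diff (hyper-parameter values are taken from those changes):

    import torch
    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

    def build_optimizer(zero_model: torch.nn.Module) -> GeminiAdamOptimizer:
        # Keyword arguments travel through **defaults to both HybridAdam
        # (lr, betas, ...) and ZeroOptimizer (initial_scale, ...).
        return GeminiAdamOptimizer(zero_model, lr=1e-3, initial_scale=2**5)

    # Equivalent to the two-step construction this commit replaces:
    #   optimizer = HybridAdam(model.parameters(), lr=1e-3)
    #   optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)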

View File

@@ -1,8 +1,10 @@
+from typing import Any, Optional
 import torch
-from colossalai.utils import multi_tensor_applier
 from colossalai.registry import OPTIMIZERS
-from typing import Optional
+from colossalai.utils import multi_tensor_applier
 from .nvme_optimizer import NVMeOptimizer
@@ -68,14 +70,15 @@ class HybridAdam(NVMeOptimizer):
                  weight_decay=0,
                  adamw_mode=True,
                  nvme_offload_fraction: float = 0.0,
-                 nvme_offload_dir: Optional[str] = None):
+                 nvme_offload_dir: Optional[str] = None,
+                 **defaults: Any):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
         try:
-            import cpu_adam
             import colossal_C
+            import cpu_adam
         except ImportError:
             raise ImportError('Please install colossalai from source code to use HybridAdam')

View File

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Set, Tuple
+from typing import Any, Dict, Set, Tuple
 import torch
 import torch.distributed as dist
@@ -55,7 +55,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  backoff_factor: float = 0.5,
                  growth_interval: int = 1000,
                  hysteresis: int = 2,
-                 max_scale: float = 2**32):
+                 max_scale: float = 2**32,
+                 **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
         self.module = module
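Both signature changes exist so that GeminiAdamOptimizer can pass one kwargs dict to HybridAdam and ZeroOptimizer and let each constructor ignore the keys meant for the other. A small sketch of that pattern with placeholder classes (these names are illustrative, not ColossalAI APIs):

    from typing import Any

    import torch

    class _InnerAdam:
        # Consumes Adam hyper-parameters, silently ignores e.g. initial_scale.
        def __init__(self, params, lr: float = 1e-3, **defaults: Any) -> None:
            self.params, self.lr = list(params), lr

    class _ZeroWrapper:
        # Consumes loss-scaling options, silently ignores e.g. lr.
        def __init__(self, optim, module, initial_scale: float = 2**16, **defaults: Any) -> None:
            self.optim, self.module, self.initial_scale = optim, module, initial_scale

    def build(module: torch.nn.Module, **defaults: Any) -> _ZeroWrapper:
        # Same shape as GeminiAdamOptimizer.__init__: one dict feeds both objects.
        return _ZeroWrapper(_InnerAdam(module.parameters(), **defaults), module, **defaults)

    wrapper = build(torch.nn.Linear(4, 4), lr=1e-3, initial_scale=2**5)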

View File

@@ -16,8 +16,9 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  search_range_mb: int = 32) -> None:
         """
-        A torch.Module warpper using ZeRODPP and Genimi.
+        A torch.Module wrapper using ZeRO-DP and Gemini.
         ZeRO is for parallel. Gemini is for memory management.
+        WARNING: The class will modify the module inline!
         Example:
             model is initialized under the context of ColoInitContext
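The docstring's requirement matters in practice: parameters must be created under ColoInitContext before the module is wrapped. A hedged end-to-end sketch assembled from the demo and test changes in this diff; the toy model and the 'cuda' placement policy are illustrative assumptions, and the positional GeminiDDP arguments mirror the test call further below (module, init_device, placement_policy, pin_memory, force_outputs_fp32, search_range_mb):

    import torch
    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.utils import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext

    # Parameters created here become ColoParameters, which ZeroDDP/GeminiDDP expect.
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.GELU(), torch.nn.Linear(32, 32))

    model = GeminiDDP(model, get_current_device(), 'cuda', True, False, 32)
    optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)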

View File

@@ -7,7 +7,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
-from .zero_optimizer import ZeroOptimizer
+from ..nn.optimizer.zero_optimizer import ZeroOptimizer
 def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,

View File

@@ -3,10 +3,8 @@ This example shows how to use Colossal-AI to run huggingface GPT training in dis
 ## GPT
 We use the GPT2 model from huggingface transformers. The input data is randonly generated.
-## Our Modifications
-The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO.
-The Colossal-AI leverages Tensor Parallel and Gemini.
+The `train_gpt_demo.py` provides three distributed plans, i.e. ColossalAI, PyTorch DDP and ZeRO.
+The ColossalAI leverages Tensor Parallel and Gemini.
 ## Quick Start
 You can launch training by using the following bash script.

View File

@@ -10,11 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import GPT2Config, GPT2LMHeadModel
@@ -222,7 +223,7 @@ def main():
     default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
     # build GPT model
-    with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg):
+    with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
         model = gpt2_medium(checkpoint=True)
     pg = default_pg
@@ -232,8 +233,9 @@ def main():
         model = gemini_zero_dpp(model, pg, args.placement)
         # build optimizer
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-        optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
+        optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
+        # optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     elif args.distplan == "ddp":

View File

@@ -43,11 +43,11 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device, get_dataloader
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,

View File

@@ -30,13 +30,24 @@ from itertools import chain
 import datasets
 import torch
 import torch.distributed as dist
-import transformers
 from accelerate.utils import set_seed
 from context import barrier_context
 from datasets import load_dataset
 from packaging import version
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
+import colossalai
+import transformers
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.tensor import ProcessGroup
+from colossalai.utils import get_current_device, get_dataloader
+from colossalai.utils.model.colo_init_context import ColoInitContext
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -50,17 +61,6 @@ from transformers import (
 )
 from transformers.utils.versions import require_version
-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ProcessGroup
-from colossalai.utils import get_current_device, get_dataloader
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())

View File

@@ -12,12 +12,12 @@ from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal

View File

@@ -9,12 +9,12 @@ import colossalai
 from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed

View File

@@ -7,16 +7,14 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.amp import convert_to_apex_amp
-from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
-from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.gemini.chunk import search_chunk_configuration
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel import GeminiDDP, ZeroDDP
 from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
 from tests.test_tensor.model.test_gpt2 import init_megatron_spec
@@ -96,19 +94,23 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
         init_device = torch.device('cpu')
     else:
         init_device = None
-    chunk_manager = ChunkManager(config_dict, init_device=init_device)
-    gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
-    optimizer = HybridAdam(model.parameters(), lr=1e-3)
-    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
+    model = GeminiDDP(model, init_device, placement_policy, True, False, 32)
+    # The same as the following 3 lines
+    # chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    # gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    # model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
+    # The same as the following 2 lines
+    # optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
-    print(chunk_manager)
     check_param(model, torch_model, pg)
     model.eval()