mirror of https://github.com/hpcaitech/ColossalAI.git
[Gemini] add GeminiAdamOptimizer (#1960)
parent 7066dfbf82
commit f7e276fa71
colossalai/nn/optimizer/gemini_optimizer.py (new file, +15 lines)
@@ -0,0 +1,15 @@
+from typing import Any
+
+import torch
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+
+__all__ = ['GeminiAdamOptimizer']
+
+
+class GeminiAdamOptimizer(ZeroOptimizer):
+
+    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
+        optimizer = HybridAdam(model.parameters(), **defaults)
+        super().__init__(optimizer, model, **defaults)
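
For orientation, here is a minimal usage sketch of the new class, distilled from the example and test changes further down. It assumes a distributed environment already launched via colossalai and a module that has already been wrapped by ZeroDDP/GeminiDDP (ZeroOptimizer asserts this); the hyperparameter values are the ones used in the updated GPT demo.

    from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

    # 'model' is assumed to be a ZeroDDP/GeminiDDP-wrapped module.
    # One constructor replaces the former two-step setup
    #     optimizer = HybridAdam(model.parameters(), lr=1e-3)
    #     optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
    # Keyword arguments are forwarded to both sides: HybridAdam keeps the Adam
    # hyperparameters (lr, betas, eps, ...), ZeroOptimizer keeps the loss-scale
    # options (initial_scale, growth_interval, ...), and the **defaults added in
    # this commit absorb whatever the other side does not recognise.
    optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)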

@@ -1,8 +1,10 @@
+from typing import Any, Optional
+
 import torch
 
-from colossalai.utils import multi_tensor_applier
 from colossalai.registry import OPTIMIZERS
-from typing import Optional
+from colossalai.utils import multi_tensor_applier
 
 from .nvme_optimizer import NVMeOptimizer

@@ -68,14 +70,15 @@ class HybridAdam(NVMeOptimizer):
                  weight_decay=0,
                  adamw_mode=True,
                  nvme_offload_fraction: float = 0.0,
-                 nvme_offload_dir: Optional[str] = None):
+                 nvme_offload_dir: Optional[str] = None,
+                 **defaults: Any):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
         try:
-            import cpu_adam
             import colossal_C
+            import cpu_adam
         except ImportError:
             raise ImportError('Please install colossalai from source code to use HybridAdam')

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Set, Tuple
+from typing import Any, Dict, Set, Tuple
 
 import torch
 import torch.distributed as dist

@@ -55,7 +55,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  backoff_factor: float = 0.5,
                  growth_interval: int = 1000,
                  hysteresis: int = 2,
-                 max_scale: float = 2**32):
+                 max_scale: float = 2**32,
+                 **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
         self.module = module
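
The `**defaults: Any` catch-alls added to HybridAdam and ZeroOptimizer above are what let GeminiAdamOptimizer forward a single flat set of keyword arguments to both constructors: each side keeps the parameters it names and silently drops the rest. A self-contained sketch of the pattern, using purely hypothetical classes for illustration:

    from typing import Any


    class InnerAdam:
        # Stands in for HybridAdam: consumes lr, ignores e.g. initial_scale.
        def __init__(self, lr: float = 1e-3, **defaults: Any) -> None:
            self.lr = lr


    class OuterZero:
        # Stands in for ZeroOptimizer: consumes initial_scale, ignores e.g. lr.
        def __init__(self, optim: InnerAdam, initial_scale: float = 2**32, **defaults: Any) -> None:
            self.optim = optim
            self.initial_scale = initial_scale


    class Combined(OuterZero):
        # Same shape as GeminiAdamOptimizer: one **defaults fans out to both.
        def __init__(self, **defaults: Any) -> None:
            super().__init__(InnerAdam(**defaults), **defaults)


    opt = Combined(lr=1e-3, initial_scale=2**5)
    assert opt.optim.lr == 1e-3 and opt.initial_scale == 2**5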

@@ -16,8 +16,9 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  search_range_mb: int = 32) -> None:
         """
-        A torch.Module warpper using ZeRODPP and Genimi.
+        A torch.Module wrapper using ZeRO-DP and Gemini.
         ZeRO is for parallel. Gemini is for memory management.
+        WARNING: The class will modify the module inline!
 
         Example:
             model is initialized under the context of ColoInitContext
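
The docstring's example is terse; the following hedged sketch shows what "model is initialized under the context of ColoInitContext" looks like in practice, mirroring the positional GeminiDDP call used in the updated tests (module, init_device, placement_policy, pin_memory, force_outputs_fp32, search_range_mb). It assumes a launched distributed environment, and the toy module is a placeholder.

    import torch
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.utils.model.colo_init_context import ColoInitContext

    # Parameters created inside this context become ColoParameters, which is
    # what Gemini's chunk-based memory management operates on.
    with ColoInitContext(device=torch.device('cpu')):
        model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())

    # Note the warning above: GeminiDDP modifies the wrapped module in place.
    model = GeminiDDP(model, torch.device('cpu'), 'cpu', True, False, 32)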

@@ -7,7 +7,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
 
-from .zero_optimizer import ZeroOptimizer
+from ..nn.optimizer.zero_optimizer import ZeroOptimizer
 
 
 def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,

@@ -3,10 +3,8 @@ This example shows how to use Colossal-AI to run huggingface GPT training in dis
 
 ## GPT
 We use the GPT2 model from huggingface transformers. The input data is randomly generated.
-The `train_gpt_demo.py` provides three distributed plans, i.e. ColossalAI, PyTorch DDP and ZeRO.
-## Our Modifications
-The ColossalAI leverages Tensor Parallel and Gemini.
+The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO.
+
+The Colossal-AI leverages Tensor Parallel and Gemini.
 
 ## Quick Start
 You can launch training by using the following bash script.

@@ -10,11 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import GPT2Config, GPT2LMHeadModel

@@ -222,7 +223,7 @@ def main():
     default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
 
     # build GPT model
-    with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg):
+    with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
         model = gpt2_medium(checkpoint=True)
 
     pg = default_pg

@@ -232,8 +233,9 @@ def main():
         model = gemini_zero_dpp(model, pg, args.placement)
 
         # build optimizer
-        optimizer = HybridAdam(model.parameters(), lr=1e-3)
-        optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
+        optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
+        # optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
 
     elif args.distplan == "ddp":
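
Putting the three example hunks above together (presumably train_gpt_demo.py; the file header was lost in extraction), the Colossal-AI plan now reads roughly as follows. This is condensed from the diff, not a verbatim excerpt; default_dist_spec, default_pg, args and the gemini_zero_dpp helper come from the surrounding example code.

    # Build the GPT-2 model on CPU under ColoInitContext, then let the example's
    # gemini_zero_dpp helper wrap it with ZeroDDP/Gemini.
    with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
        model = gpt2_medium(checkpoint=True)

    model = gemini_zero_dpp(model, default_pg, args.placement)

    # A single call now replaces the HybridAdam + ZeroOptimizer pair.
    optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)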

@@ -43,11 +43,11 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ProcessGroup
 from colossalai.utils import get_current_device, get_dataloader
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,

@@ -30,13 +30,24 @@ from itertools import chain
 import datasets
 import torch
 import torch.distributed as dist
-import transformers
 from accelerate.utils import set_seed
 from context import barrier_context
 from datasets import load_dataset
 from packaging import version
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
+import colossalai
+import transformers
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.tensor import ProcessGroup
+from colossalai.utils import get_current_device, get_dataloader
+from colossalai.utils.model.colo_init_context import ColoInitContext
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,

@@ -50,17 +61,6 @@ from transformers import (
 )
 from transformers.utils.versions import require_version
 
-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ProcessGroup
-from colossalai.utils import get_current_device, get_dataloader
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
-
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())

@@ -12,12 +12,12 @@ from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal

@@ -9,12 +9,12 @@ import colossalai
 from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
 from colossalai.nn.parallel import ZeroDDP
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed

@@ -7,16 +7,14 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
 from colossalai.amp import convert_to_apex_amp
-from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
-from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.gemini.chunk import search_chunk_configuration
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel import GeminiDDP, ZeroDDP
 from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
 from tests.test_tensor.model.test_gpt2 import init_megatron_spec

@@ -96,19 +94,23 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
         init_device = torch.device('cpu')
     else:
         init_device = None
-    chunk_manager = ChunkManager(config_dict, init_device=init_device)
-    gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
 
-    optimizer = HybridAdam(model.parameters(), lr=1e-3)
-    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
+    model = GeminiDDP(model, init_device, placement_policy, True, False, 32)
+    # The same as the following 3 lines
+    # chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    # gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    # model = ZeroDDP(model, gemini_manager, pin_memory=True)
+
+    zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
+    # The same as the following 2 lines
+    # optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    # zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
 
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
 
-    print(chunk_manager)
     check_param(model, torch_model, pg)
 
     model.eval()