Mirror of https://github.com/hpcaitech/ColossalAI.git
Merge branch 'main' into feature/shardformer
@@ -1,13 +1,11 @@
 import gc
 import logging
 import os
-import warnings
 from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
-from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
@@ -16,7 +14,6 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralC
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
     get_optimizer_base_filenames,
     get_shard_filename,
     load_shard_state_dict,
     save_config_file,
     save_state_dict,
@@ -25,8 +22,7 @@ from colossalai.checkpoint_io.utils import (
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.zero.gemini import ZeroOptimizer
+from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats

 from .dp_plugin_base import DPPluginBase
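In other words, the file drops the `zero_model_wrapper`/`zero_optim_wrapper` helpers and the old `ZeroOptimizer` class in favour of importing `GeminiDDP` and the renamed `GeminiOptimizer` directly. A minimal sketch of the import surface the merged file relies on (assuming a ColossalAI build that already ships the rename):

# Sketch only: the consolidated imports implied by the hunk above.
from colossalai.zero import GeminiDDP, GeminiOptimizer  # GeminiOptimizer supersedes ZeroOptimizer here
from colossalai.zero.gemini.memory_tracer import MemStats  # unchanged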
@@ -134,11 +130,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         As there is communication when getting state dict, this must be called on all processes.
         """

-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = optimizer.unwrap()
-
-        assert isinstance(optimizer, ZeroOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer)

         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -185,11 +177,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if not os.path.isfile(checkpoint_index_file):
             logging.error(f"Provided path ({checkpoint_index_file}) should be a file")

-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = optimizer.unwrap()
-
-        assert isinstance(optimizer, ZeroOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer)

         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
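Both checkpoint hunks now assume the optimizer reaching GeminiCheckpointIO is already a `colossalai.zero.GeminiOptimizer`, so the unwrap step disappears. A rough sketch of how these code paths are normally reached through the Booster checkpoint API; the tiny model and the checkpoint paths are illustrative, and the script assumes a torchrun-style distributed launch:

import colossalai
import torch.nn as nn
from torch.optim import Adam
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})  # assumes the process group is set up by torchrun

model = nn.Linear(8, 8)
optimizer = Adam(model.parameters())

booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Saving routes through GeminiCheckpointIO and hits the asserts shown above.
booster.save_optimizer(optimizer, "ckpt/optim", shard=True)  # writes shards plus an index file
# Loading mirrors the save: booster.load_optimizer(optimizer, <path to the saved index file>)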
@@ -222,47 +210,6 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         super().save_lr_scheduler(lr_scheduler, checkpoint)


-class GeminiModel(ModelWrapper):
-
-    def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
-        super().__init__(module)
-        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)
-
-    def unwrap(self):
-        # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
-        return self.module
-
-
-class GeminiOptimizer(OptimizerWrapper):
-
-    def __init__(self,
-                 module: GeminiDDP,
-                 optimizer: Optimizer,
-                 zero_optim_config: dict,
-                 optim_kwargs: dict,
-                 verbose: bool = False) -> None:
-        optimizer = zero_optim_wrapper(module,
-                                       optimizer,
-                                       optim_config=zero_optim_config,
-                                       **optim_kwargs,
-                                       verbose=verbose)
-        super().__init__(optimizer)
-
-    def backward(self, loss: Tensor, *args, **kwargs):
-        self.optim.backward(loss)
-
-    def clip_grad_by_norm(self,
-                          max_norm: Union[float, int],
-                          norm_type: Union[float, int] = 2,
-                          error_if_nonfinite: bool = False,
-                          *args,
-                          **kwargs) -> Tensor:
-        warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
-
-    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
-        raise NotImplementedError('Gemini does not support clip_grad_by_value')
-
-
 class GeminiPlugin(DPPluginBase):
     """
     Plugin for Gemini.
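The two thin wrapper classes deleted above are not lost behaviour: `colossalai.zero.GeminiDDP` and `colossalai.zero.GeminiOptimizer` take over their roles, as the configure() hunk near the end of this diff shows. A minimal sketch of that direct construction, using only keyword names that appear elsewhere in this diff and assuming an already-initialized distributed environment:

import torch.nn as nn
from torch.optim import Adam
from colossalai.zero import GeminiDDP, GeminiOptimizer

# Wrap the module directly; previously GeminiModel did this via zero_model_wrapper.
model = GeminiDDP(
    nn.Linear(8, 8),
    chunk_config_dict=None,       # let Gemini derive a chunk configuration
    placement_policy="static",
    shard_param_frac=1.0,         # 1.0 behaves like ZeRO-3 style full parameter sharding
)

# Wrap the optimizer directly; note the (optimizer, module) argument order used in configure() below.
optimizer = GeminiOptimizer(Adam(model.parameters()), model, verbose=False)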
@@ -279,8 +226,20 @@ class GeminiPlugin(DPPluginBase):
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

     Args:
-        device (torch.device): device to place the model.
-        placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+        chunk_config_dict (dict, optional): chunk configuration dictionary.
+        chunk_init_device (torch.device, optional): device to initialize the chunk.
+        placement_policy (str, optional): "static" and "auto". Defaults to "static".
+        shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
+            If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
+        offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
+            If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
+        offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
+            For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
+            If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
+            When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`.
+            Defaults to 0.0.
+        warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
+        steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
         precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
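In effect, the old coarse placement_policy choices ("cpu", "cuda", "auto") become "auto" plus a "static" mode tuned through three fractions, and the docstring above states how the old behaviours map onto them. A hedged configuration sketch built only from the argument names and defaults documented in this hunk (it assumes colossalai.launch_from_torch(...) has already been called, since the plugin base class expects an initialized process group):

from colossalai.booster.plugin import GeminiPlugin

# Roughly the old "cuda" placement: fully shard parameters, keep optimizer states on GPU.
cuda_like = GeminiPlugin(placement_policy="static", shard_param_frac=1.0, offload_optim_frac=0.0)

# Roughly the old "cpu" placement: shard parameters and offload both optimizer states and parameters.
cpu_like = GeminiPlugin(
    placement_policy="static",
    shard_param_frac=1.0,
    offload_optim_frac=1.0,
    offload_param_frac=1.0,
)

# "auto" placement keeps its memory-headroom knobs.
auto = GeminiPlugin(placement_policy="auto", warmup_non_model_data_ratio=0.8, steady_cuda_cap_ratio=0.9)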
@@ -312,8 +271,14 @@ class GeminiPlugin(DPPluginBase):

     def __init__(
         self,
-        device: Optional[torch.device] = None,
-        placement_policy: str = "cpu",
+        chunk_config_dict: Optional[dict] = None,
+        chunk_init_device: Optional[torch.device] = None,
+        placement_policy: str = "static",
+        shard_param_frac: float = 1.0,  # only for static placement
+        offload_optim_frac: float = 0.0,  # only for static placement
+        offload_param_frac: float = 0.0,  # only for static placement
+        warmup_non_model_data_ratio: float = 0.8,  # only for auto placement
+        steady_cuda_cap_ratio: float = 0.9,  # only for auto placement
         precision: str = "fp16",
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
@@ -337,8 +302,14 @@ class GeminiPlugin(DPPluginBase):
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
         self.gemini_config = dict(
-            device=(device or get_current_device()),
+            chunk_config_dict=chunk_config_dict,
+            chunk_init_device=(chunk_init_device or get_current_device()),
             placement_policy=placement_policy,
+            shard_param_frac=shard_param_frac,
+            offload_optim_frac=offload_optim_frac,
+            offload_param_frac=offload_param_frac,
+            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+            steady_cuda_cap_ratio=steady_cuda_cap_ratio,
             pin_memory=pin_memory,
             force_outputs_fp32=force_outputs_fp32,
             strict_ddp_mode=strict_ddp_mode,
@@ -395,12 +366,15 @@ class GeminiPlugin(DPPluginBase):
         # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

         # wrap the model with Gemini
-        model = GeminiModel(model, self.gemini_config, self.verbose)
+        model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)

         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
-                                        self.verbose)
+            optimizer = GeminiOptimizer(optimizer,
+                                        model.unwrap(),
+                                        **self.zero_optim_config,
+                                        **self.optim_kwargs,
+                                        verbose=self.verbose)

         return model, optimizer, criterion, dataloader, lr_scheduler
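With this final hunk, configure() hands the raw nn.Module to GeminiDDP and the raw Optimizer to the new GeminiOptimizer instead of the deleted local wrappers. End to end, training with the plugin still follows the docstring example; a sketch under the assumption of a torchrun-style distributed launch:

import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

booster = Booster(plugin=GeminiPlugin())
model, optimizer, criterion, *_ = booster.boost(model, optimizer, criterion=criterion)

x, y = torch.randn(2, 16, device="cuda"), torch.randn(2, 4, device="cuda")
loss = criterion(model(x).float(), y)  # cast fp16 outputs back to fp32 for the loss
booster.backward(loss, optimizer)      # Gemini scales the loss and manages grad clipping itself
optimizer.step()
optimizer.zero_grad()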