Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 12:01:39 +00:00)
[zero] hotfix master param sync (#4618)
* [zero] add method to update master params
* [zero] update zero plugin
* [plugin] update low level zero plugin
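Background for the diff below: with fp16/bf16 training, the low-level ZeRO optimizer steps on fp32 master copies of the working parameters, so loading a checkpoint into the module alone would leave those master copies stale. This commit makes the checkpoint IO call `model.update_master_params()` after every model load, with the method injected into the model by the plugin. A minimal sketch of the re-sync idea, assuming a hypothetical `sync_master_params` helper (not ColossalAI's actual API):

```python
# Illustrative sketch only: `sync_master_params` is a hypothetical helper,
# not the ColossalAI implementation.
import torch

def sync_master_params(working_params, master_params):
    """Copy the (possibly fp16/bf16) working params back into their fp32 master copies."""
    with torch.no_grad():
        for working, master in zip(working_params, master_params):
            # Without this re-sync, the optimizer keeps stepping from the stale
            # fp32 values that existed before the checkpoint was loaded.
            master.copy_(working.detach().to(master.dtype))
```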
@@ -3,6 +3,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
+from types import MethodType
 from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
     unwrap_optimizer,
 )
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer

 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']


+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+    def __init__(self, module: nn.Module, precision: str) -> None:
+        super().__init__(module)
+        self.dtype = None
+        if precision == 'fp16':
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        if self.dtype is not None:
+            module = module.to(self.dtype)
+        module = module.to(get_current_device())
+        self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
+    def unwrap(self):
+        # TODO(ver217): this is a workaround for loading model
+        return self
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
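The new `LowLevelZeroModel` above casts the module to the target dtype once and then converts floating-point inputs on every `forward` through `tree_map`. A self-contained sketch of that casting pattern; `convert_fp` below is an assumed stand-in for this file's `_convert_floating_point` helper:

```python
# Sketch of the tree_map-based input casting used by LowLevelZeroModel.forward;
# `convert_fp` is an assumed stand-in for _convert_floating_point.
import torch
from torch.utils._pytree import tree_map

def convert_fp(x, dtype: torch.dtype = torch.float16):
    # Only floating-point tensors are cast; ints, bools and non-tensors pass through.
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x

inputs = (torch.randn(2, 3), {"mask": torch.ones(2, 3, dtype=torch.bool)})
fp16_inputs = tree_map(convert_fp, inputs)  # the bool mask is left untouched
```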
@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

         sharded_optimizer_loading_epilogue(optimizer)

+    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+                             use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)

-class LowLevelZeroModel(ModelWrapper):
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+                                   use_safetensors)

-    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
-        super().__init__(module)
-        self.dtype = None
-        if precision == 'fp16':
-            self.dtype = torch.float16
-        elif precision == 'bf16':
-            self.dtype = torch.bfloat16
-        module = zero_model_wrapper(module, zero_stage=stage)
-        if self.dtype is not None:
-            module = module.to(self.dtype)
-        module = module.to(get_current_device())
-        self.module = module
-        self.convert_fn = None
-        if self.dtype is not None:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_unsharded_model(model.module, checkpoint, strict)
+        model.update_master_params()

-    def forward(self, *args, **kwargs):
-        if self.convert_fn is not None:
-            args = tree_map(self.convert_fn, args)
-            kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+    def load_sharded_model(self,
+                           model: LowLevelZeroModel,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()


 class LowLevelZeroPlugin(DPPluginBase):
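With the overrides above, both sharded and unsharded model loading finish with `model.update_master_params()`, so the optimizer's fp32 master weights are refreshed right after the state dict lands in `model.module`. A hedged end-to-end sketch of how this is exercised through the Booster API; the checkpoint path is a placeholder and the launch/boost details may differ across ColossalAI versions:

```python
# Usage sketch (placeholder checkpoint path; API details may vary by version).
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})
plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16)  # any nn.Module works here
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# After this fix, loading also re-syncs the optimizer's fp32 master params.
booster.load_model(model, 'ckpt/model.safetensors')
```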
@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase):
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
-
+        assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
         self.stage = stage
         self.precision = precision
-        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
-                                      communication_dtype=communication_dtype,
-                                      overlap_communication=overlap_communication,
-                                      cpu_offload=cpu_offload)
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+        self.zero_optim_kwargs = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+            clip_grad_norm=max_norm,
+            reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            partition_grad=(stage == 2),
+        )
         self.verbose = verbose

         # set class name with stage, for better error message
@@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.stage, self.precision)
+            model = LowLevelZeroModel(model, self.precision)

         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = zero_optim_wrapper(model.unwrap(),
-                                           optimizer,
-                                           optim_config=self.zero_optim_config,
-                                           **self.optim_kwargs,
-                                           verbose=self.verbose)
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+                                                                     **self.zero_optim_kwargs,
+                                                                     verbose=self.verbose)
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)

         return model, optimizer, criterion, dataloader, lr_scheduler
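The `MethodType` injection above means `model.update_master_params()` forwards to `optimizer.update_master_params(model)`, so the checkpoint IO can trigger the sync through the model it already holds, without needing a handle on the optimizer. The same standard-library pattern in a toy, self-contained form (the `Optimizer`/`Model` classes below are illustrative, not ColossalAI code):

```python
# Toy illustration of binding an (already bound) method onto another object.
from types import MethodType

class Optimizer:
    def update_master_params(self, model):
        print(f"syncing master params of {model!r}")

class Model:
    pass

opt, net = Optimizer(), Model()
# net becomes the argument passed to opt.update_master_params:
# net.update_master_params()  ==  opt.update_master_params(net)
net.update_master_params = MethodType(opt.update_master_params, net)
net.update_master_params()
```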