[zero] hotfix master param sync (#4618)

* [zero] add method to update master params

* [zero] update zero plugin

* [plugin] update low level zero plugin
This commit is contained in:
Hongxin Liu
2023-09-05 15:04:02 +08:00
committed by GitHub
parent aaeb520ce3
commit 807e01a4ba
5 changed files with 122 additions and 45 deletions

View File

@@ -3,6 +3,7 @@ import os
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
sharded_optimizer_loading_epilogue,
unwrap_optimizer,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
self.dtype = None
if precision == 'fp16':
self.dtype = torch.float16
elif precision == 'bf16':
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
def unwrap(self):
# TODO(ver217): this is a workaround for loading model
return self
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
sharded_optimizer_loading_epilogue(optimizer)
def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
use_safetensors: bool):
assert isinstance(model, LowLevelZeroModel)
super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
class LowLevelZeroModel(ModelWrapper):
def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
assert isinstance(model, LowLevelZeroModel)
super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
use_safetensors)
def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
super().__init__(module)
self.dtype = None
if precision == 'fp16':
self.dtype = torch.float16
elif precision == 'bf16':
self.dtype = torch.bfloat16
module = zero_model_wrapper(module, zero_stage=stage)
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_unsharded_model(model.module, checkpoint, strict)
model.update_master_params()
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
def load_sharded_model(self,
model: LowLevelZeroModel,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
class LowLevelZeroPlugin(DPPluginBase):
@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase):
super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
self.stage = stage
self.precision = precision
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload)
self.optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
self.zero_optim_kwargs = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
clip_grad_norm=max_norm,
reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(stage == 2),
)
self.verbose = verbose
# set class name with stage, for better error message
@@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase):
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.stage, self.precision)
model = LowLevelZeroModel(model, self.precision)
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = zero_optim_wrapper(model.unwrap(),
optimizer,
optim_config=self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
**self.zero_optim_kwargs,
verbose=self.verbose)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler