[zero] support shard optimizer state dict of zero (#4194)

* support shard optimizer of zero

* polish code

* support sync grad manually
LuGY authored on 2023-07-11 18:03:13 +08:00, committed by Hongxin Liu
parent dd7cc58299
commit 1a49a5ea00
4 changed files with 239 additions and 68 deletions
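
For context, here is a minimal usage sketch of the sharded optimizer checkpointing this commit adds. It is a sketch, not part of the commit: the toy model and paths are hypothetical, stage=2 is an arbitrary choice, and the Booster.save_optimizer / Booster.load_optimizer call shapes are assumed from the surrounding ColossalAI API of this period.

    # run with e.g.: torchrun --nproc_per_node=2 train.py
    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    colossalai.launch_from_torch(config={})
    plugin = LowLevelZeroPlugin(stage=2)    # assumption: a stage-2 ZeRO config
    booster = Booster(plugin=plugin)

    model = torch.nn.Linear(8, 8).cuda()    # hypothetical toy model
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

    # shard=True routes to LowLevelZeroCheckpointIO.save_sharded_optimizer below
    booster.save_optimizer(optimizer, 'optim_ckpt', shard=True, size_per_shard=1024)
    # loading takes the path to the index file written at save time
    booster.load_optimizer(optimizer, 'optim_ckpt/pytorch_optim.bin.index.json')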


@@ -1,5 +1,8 @@
+import logging
+import os
 import warnings
 from functools import partial
+from pathlib import Path
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -10,10 +13,16 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
+from colossalai.checkpoint_io.utils import (
+    get_optimizer_base_filenames,
+    get_shard_filename,
+    save_param_groups,
+    save_state_dict,
+)
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
 
 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -32,21 +41,104 @@ SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
-        """
-        Save optimizer to checkpoint but only on master process.
-        """
-        # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
-        warnings.warn(
-            'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+        """Save optimizer to checkpoint but only on master process.
 
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
-        warnings.warn(
-            'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        super().load_optimizer(optimizer, checkpoint)
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to save state_dict
+            checkpoint (str): Path to save checkpoint
+            gather_dtensor (bool): Whether to gather the distributed tensor (not used)
+        """
+        # state_dict() of LowLevelZeroOptimizer performs collective communication,
+        # so if only the master rank collected the state dict and saved it,
+        # the communication on the other ranks would not match
+        state_dict = optimizer.state_dict()
+        if self.coordinator.is_master():
+            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def save_sharded_optimizer(self,
+                               optimizer: OptimizerWrapper,
+                               checkpoint: str,
+                               gather_dtensor: bool = False,
+                               prefix: str = None,
+                               size_per_shard: int = 1024):
+        """
+        Save sharded ZeRO optimizer checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+        - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files (pytorch_optim-000XX.bin) that store the optimizer state tensors in shards
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+            checkpoint (str): Path to save optimizer state_dict
+            gather_dtensor (bool): Whether to gather the distributed tensor (not used)
+            prefix (str): Prefix of the files to save
+            size_per_shard (int): Max file size of each file that stores state tensors
+        """
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # this state_dict provides only 'param_groups'
+        state_dict = optimizer.optim.state_dict()
+        # state sharding is handled by the low-level zero optimizer
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+
+        # Preparing file paths and index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+
+        # Store the information of param groups to param_group_file.
+        index_file.append_meta_data("param_groups", param_group_file)
+        group_file_path = os.path.join(checkpoint, param_group_file)
+        save_param_groups(state_dict, group_file_path)
+
+        # Save shards of optimizer states.
+        total_size = 0
+        for idx, shard_pair in enumerate(sharded_state):
+            shard, current_size = shard_pair
+            shard_file = get_shard_filename(states_name, idx)
+            total_size = total_size + current_size
+            for param_id in shard.keys():
+                index_file.append_weight_map(str(param_id), shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint, shard_file)
+            if self.coordinator.is_master():
+                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+        # Wrap up index file.
+        index_file.append_meta_data("total_size", total_size)
+        if self.coordinator.is_master():
+            index_file.write_index_file(save_index_file)
+        logging.info(f"The optimizer is going to be split into checkpoint shards. "
+                     f"You can find where each parameter has been saved in the "
+                     f"index located at {save_index_file}.")
+
+    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
+        """Load sharded optimizer with the given path to the index file.
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to load state_dict
+            index_file_path (str): Path to the index file
+            prefix (str): Not used.
+        """
+        super().load_sharded_optimizer(optimizer, index_file_path, prefix)
+        current_rank_state_dict = optimizer.optim.state_dict()['state']
+        for param_idx, state in current_rank_state_dict.items():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor) and k != 'step':
+                    padding_size = (self.coordinator.world_size -
+                                    v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+                    with torch.no_grad():
+                        v = v.flatten()
+                        if padding_size > 0:
+                            v = torch.nn.functional.pad(v, [0, padding_size])
+                        v_list = v.split(v.numel() // self.coordinator.world_size)
+                        current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
 
 
 class LowLevelZeroModel(ModelWrapper):
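
The split logic in load_sharded_optimizer above mirrors how the low-level ZeRO optimizer partitions flattened states: each loaded state tensor is flattened, zero-padded up to a multiple of world_size, split into equal chunks, and each rank keeps only its own chunk. A standalone sketch of that arithmetic (plain PyTorch, no distributed setup; world_size and rank stand in for self.coordinator.world_size and self.coordinator.rank):

    import torch

    def shard_for_rank(v: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
        # pad so the flattened length divides evenly by world_size
        padding_size = (world_size - v.numel() % world_size) % world_size
        v = v.flatten()
        if padding_size > 0:
            v = torch.nn.functional.pad(v, [0, padding_size])
        # equal-sized chunks; each rank keeps exactly one
        v_list = v.split(v.numel() // world_size)
        return v_list[rank]

    # a 10-element state on 4 ranks is padded to 12 and split into chunks of 3
    full = torch.arange(10.)
    print(shard_for_rank(full, world_size=4, rank=3))    # tensor([9., 0., 0.])

The save path is the mirror image: optimizer.state_dict() involves collective communication, so every rank must call it even though only the master rank writes the file.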
@@ -74,36 +166,6 @@ class LowLevelZeroModel(ModelWrapper):
         return super().forward(*args, **kwargs)
 
-class LowLevelZeroOptimizer(OptimizerWrapper):
-
-    def __init__(self,
-                 module: nn.Module,
-                 optimizer: Optimizer,
-                 zero_optim_config: dict,
-                 optim_kwargs: dict,
-                 verbose: bool = False) -> None:
-        optimizer = zero_optim_wrapper(module,
-                                       optimizer,
-                                       optim_config=zero_optim_config,
-                                       **optim_kwargs,
-                                       verbose=verbose)
-        super().__init__(optimizer)
-
-    def backward(self, loss: Tensor, *args, **kwargs):
-        self.optim.backward(loss)
-
-    def clip_grad_by_norm(self,
-                          max_norm: Union[float, int],
-                          norm_type: Union[float, int] = 2,
-                          error_if_nonfinite: bool = False,
-                          *args,
-                          **kwargs) -> Tensor:
-        warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')
-
-    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
-        raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
-
 
 class LowLevelZeroPlugin(DPPluginBase):
     """
     Plugin for low level zero.
@@ -211,8 +273,11 @@ class LowLevelZeroPlugin(DPPluginBase):
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
-                                              self.verbose)
+            optimizer = zero_optim_wrapper(model.unwrap(),
+                                           optimizer,
+                                           optim_config=self.zero_optim_config,
+                                           **self.optim_kwargs,
+                                           verbose=self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
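
For reference, after save_sharded_optimizer runs, the checkpoint directory holds the files its docstring names. An illustrative layout (the two-shard split is made up for the example; file names follow get_optimizer_base_filenames with no prefix):

    optim_ckpt/
        pytorch_optim.bin.index.json    # maps each param id to its shard file; records total_size
        pytorch_optim_group.bin         # param_groups written by save_param_groups
        pytorch_optim-00001.bin         # first shard of optimizer state tensors
        pytorch_optim-00002.bin         # second shard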