mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[zero] support shard optimizer state dict of zero (#4194)
* support shard optimizer of zero * polish code * support sync grad manually
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -10,10 +13,16 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
get_optimizer_base_filenames,
|
||||
get_shard_filename,
|
||||
save_param_groups,
|
||||
save_state_dict,
|
||||
)
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
|
||||
|
||||
from .dp_plugin_base import DPPluginBase
|
||||
from .torch_ddp_plugin import TorchDDPCheckpointIO
|
||||
@@ -32,21 +41,104 @@ SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
|
||||
|
||||
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
# TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
|
||||
warnings.warn(
|
||||
'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
|
||||
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
|
||||
GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
|
||||
"""Save optimizer to checkpoint but only on master process.
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
||||
warnings.warn(
|
||||
'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
|
||||
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
|
||||
super().load_optimizer(optimizer, checkpoint)
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save state_dict
|
||||
checkpoint (str): Path to save checkpoint
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used
|
||||
"""
|
||||
|
||||
# the `state_dict` in LowLevelZeroOptimizer has communication
|
||||
# if only the master rank collect state_dict and save,
|
||||
# the communication on each rank would not match
|
||||
state_dict = optimizer.state_dict()
|
||||
if self.coordinator.is_master():
|
||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||
|
||||
def save_sharded_optimizer(self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = False,
|
||||
prefix: str = None,
|
||||
size_per_shard: int = 1024):
|
||||
"""
|
||||
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
|
||||
checkpoint (str): Path to save optimizer state_dict
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used
|
||||
prefix (str): Perfix of file to save
|
||||
size_per_shard (int): Max file size of each file that store state tensors
|
||||
"""
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# state_dict only provide only 'param_groups'
|
||||
state_dict = optimizer.optim.state_dict()
|
||||
# state shard would be handled by the low-level zero optimizer
|
||||
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
|
||||
|
||||
# Preparing file paths and index file.
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
|
||||
# Store the information of param groups to param_group_file.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(state_dict, group_file_path)
|
||||
|
||||
# Save shards of optimizer states.
|
||||
total_size = 0
|
||||
for idx, shard_pair in enumerate(sharded_state):
|
||||
shard, current_size = shard_pair
|
||||
shard_file = get_shard_filename(states_name, idx)
|
||||
total_size = total_size + current_size
|
||||
for param_id in shard.keys():
|
||||
index_file.append_weight_map(str(param_id), shard_file)
|
||||
|
||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||
if self.coordinator.is_master():
|
||||
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
|
||||
|
||||
# Wrap up index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
if self.coordinator.is_master():
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}.")
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
|
||||
"""Load sharded optimizer with the given path to index file.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to load state_dict
|
||||
index_file_path (str): Path to the index file
|
||||
prefix (str): Not used.
|
||||
"""
|
||||
super().load_sharded_optimizer(optimizer, index_file_path, prefix)
|
||||
current_rank_state_dict = optimizer.optim.state_dict()['state']
|
||||
for param_idx, state in current_rank_state_dict.items():
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor) and k != 'step':
|
||||
padding_size = (self.coordinator.world_size -
|
||||
v.numel() % self.coordinator.world_size) % self.coordinator.world_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
v_list = v.split(v.numel() // self.coordinator.world_size)
|
||||
current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
|
||||
|
||||
|
||||
class LowLevelZeroModel(ModelWrapper):
|
||||
@@ -74,36 +166,6 @@ class LowLevelZeroModel(ModelWrapper):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
zero_optim_config: dict,
|
||||
optim_kwargs: dict,
|
||||
verbose: bool = False) -> None:
|
||||
optimizer = zero_optim_wrapper(module,
|
||||
optimizer,
|
||||
optim_config=zero_optim_config,
|
||||
**optim_kwargs,
|
||||
verbose=verbose)
|
||||
super().__init__(optimizer)
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs):
|
||||
self.optim.backward(loss)
|
||||
|
||||
def clip_grad_by_norm(self,
|
||||
max_norm: Union[float, int],
|
||||
norm_type: Union[float, int] = 2,
|
||||
error_if_nonfinite: bool = False,
|
||||
*args,
|
||||
**kwargs) -> Tensor:
|
||||
warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')
|
||||
|
||||
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
|
||||
raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
|
||||
|
||||
|
||||
class LowLevelZeroPlugin(DPPluginBase):
|
||||
"""
|
||||
Plugin for low level zero.
|
||||
@@ -211,8 +273,11 @@ class LowLevelZeroPlugin(DPPluginBase):
|
||||
|
||||
if optimizer is not None and \
|
||||
not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
|
||||
self.verbose)
|
||||
optimizer = zero_optim_wrapper(model.unwrap(),
|
||||
optimizer,
|
||||
optim_config=self.zero_optim_config,
|
||||
**self.optim_kwargs,
|
||||
verbose=self.verbose)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
||||
|
Reference in New Issue
Block a user