Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)

[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)
* sharded optimizer checkpoint for gemini plugin
* modify test to reduce testing time
* update doc
* fix bug when keep_gatherd is true under GeminiPlugin
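
For orientation, below is a rough usage sketch of what this commit enables at the Booster level (not part of the diff). The shard/size_per_shard keywords and the index-file names are assumptions based on the checkpoint_io conventions of this release; treat the snippet as illustrative rather than as the exact API.

# Hedged sketch: assumes a torchrun launch with CUDA devices; the generated
# index-file names below are assumptions, not guaranteed by this commit.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

model = torch.nn.Linear(32, 32).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... a few training steps ...

# Every rank must make these calls; only the master rank writes files.
booster.save_model(model, "ckpt/model", shard=True)
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=1024)

# Resuming later: point the loaders at the index files written above
# (assumed names shown; check the checkpoint folders for the real ones).
booster.load_model(model, "ckpt/model/pytorch_model.bin.index.json")
booster.load_optimizer(optimizer, "ckpt/optimizer/pytorch_optim.bin.index.json")
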
@@ -1,3 +1,4 @@
+import gc
 import logging
 import os
 import warnings
@@ -12,11 +13,19 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
-from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
+from colossalai.checkpoint_io.utils import (
+    get_model_base_filenames,
+    get_optimizer_base_filenames,
+    get_shard_filename,
+    load_shard_state_dict,
+    save_state_dict,
+    save_state_dict_shards,
+)
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero.gemini import ZeroOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats

 from .dp_plugin_base import DPPluginBase
@@ -37,7 +46,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Save sharded model to checkpoint but only on master process.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
-        As there is communication when getting state dict, this must be called on all processes.
+        As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
@@ -54,7 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
-        As there is communication when getting state dict, this must be called on all processes.
+        As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
         The saving process will only be executed by master rank.
         """
         state_dict = optimizer.state_dict()
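
The two docstring fixes above make the calling contract explicit: gathering a Gemini state dict involves collective communication, so every rank must make the call even though only the master rank keeps and writes the result. A minimal sketch of that pattern follows (assuming `model` is GeminiDDP-wrapped, `optimizer` is the Gemini ZeroOptimizer, and the process group is already initialized; this is not the plugin's own code):

import torch
import torch.distributed as dist

# Collective calls: every rank must reach these lines, or the gather hangs.
model_state = model.state_dict(only_rank_0=True)   # full dict only on rank 0
optim_state = optimizer.state_dict()               # collected on the master rank

# Only the master rank ends up with the complete dicts, so only it writes.
if dist.get_rank() == 0:
    torch.save(model_state, "model.pt")
    torch.save(optim_state, "optimizer.pt")
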
@@ -76,7 +85,8 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
-        Save sharded model
+        Save sharded model.
+        As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
@@ -86,28 +96,24 @@ class GeminiCheckpointIO(GeneralCheckpointIO):

         state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
-        total_size = 0
         index_file = CheckpointIndexFile(checkpoint_path)
-        for idx, shard_pair in enumerate(state_dict_shard):
-            if not self.coordinator.is_master():
-                continue
-            shard = shard_pair[0]
-            shard_file = get_shard_filename(weights_name, idx)
-            total_size = total_size + shard_pair[1]
-            for key in shard.keys():
-                index_file.append_weight_map(key, shard_file)
-
-            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
-            save_state_dict(shard, checkpoint_file_path, use_safetensors)
-
-        index_file.append_meta_data("total_size", total_size)
+
+        # Save shards of optimizer states.
+        is_master = self.coordinator.is_master()
+        total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+                                            checkpoint=checkpoint_path,
+                                            index_file=index_file,
+                                            base_filename=weights_name,
+                                            is_master=is_master,
+                                            use_safetensors=use_safetensors)
+
         # only save the index file on the master rank
         if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-        logging.info(f"The model is split into checkpoint shards. "
-                     f"You can find where each parameters has been saved in the "
-                     f"index located at {save_index_file}.")
+            logging.info(f"The model is split into checkpoint shards. "
+                         f"You can find where each parameters has been saved in the "
+                         f"index located at {save_index_file}.")

     def load_sharded_model(self,
                            model: GeminiDDP,
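
The hunk above swaps the hand-rolled shard loop for the shared save_state_dict_shards helper. A simplified, self-contained sketch of the pattern that helper consolidates is shown below (the function is illustrative, not the real colossalai.checkpoint_io implementation):

import os
import torch

def save_shards_sketch(shard_iter, checkpoint_dir, base_name, is_master):
    """Drain an iterator of (shard_dict, shard_bytes) pairs and build an index.

    Every rank must drain the iterator, because producing a shard may involve
    communication; only the master rank writes files and records entries.
    Returns the total size and a key -> filename weight map for the index.
    """
    weight_map, total_size = {}, 0
    for idx, (shard, shard_size) in enumerate(shard_iter):
        if not is_master:
            continue
        shard_file = f"{base_name}-{idx + 1:05d}.bin"   # naming is an assumption
        for key in shard:
            weight_map[key] = shard_file
        torch.save(shard, os.path.join(checkpoint_dir, shard_file))
        total_size += shard_size
    return total_size, weight_map

Centralizing this loop lets the model and optimizer paths share the same master-only writing and index bookkeeping instead of repeating it per method.
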
@@ -115,7 +121,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                            strict: bool = False,
                            use_safetensors: bool = False):
         """
-        load shard model, load model from multiple files
+        Load shard model, load model from multiple files.
         """
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

@@ -125,16 +131,93 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Save sharded optimizer state dict to checkpoint folder.
         As there is communication when getting state dict, this must be called on all processes.
         """
+
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.unwrap()
+
+        assert isinstance(optimizer, ZeroOptimizer)
+
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
-        super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+
+        # Preparing file paths and index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+
+        # Store the information of param groups to param_group_file.
+        index_file.append_meta_data("param_groups", param_group_file)
+        group_file_path = os.path.join(checkpoint, param_group_file)
+        param_groups = optimizer.get_param_groups_for_saving()
+        torch.save(param_groups, group_file_path)
+
+        # States are broken into shards within max_shard_size.
+        state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
+
+        # Save shards of optimizer states.
+        is_master = self.coordinator.is_master()
+        total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+                                            checkpoint=checkpoint,
+                                            index_file=index_file,
+                                            base_filename=states_name,
+                                            is_master=is_master,
+                                            use_safetensors=False)
+
+        # Wrap up index file. Only save it on master rank.
+        if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            logging.info(f"The optimizer is going to be split to checkpoint shards. "
+                         f"You can find where each parameters has been saved in the "
+                         f"index located at {save_index_file}.")

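
After a successful save, the method above leaves a checkpoint folder shaped roughly like this; the concrete file names come from get_optimizer_base_filenames(prefix) and the shard numbering, so the names shown are assumptions:

ckpt/optimizer/
├── pytorch_optim.bin.index.json      (weight map plus "param_groups"/"total_size" metadata)
├── pytorch_optim_group.bin           (param groups saved via torch.save; hypothetical name)
├── pytorch_optim-00001-of-00002.bin
└── pytorch_optim-00002-of-00002.bin

The index file is exactly what load_sharded_optimizer below consumes: it recovers the param-group file through get_param_group_filename() and the shard list through get_checkpoint_filenames().
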
     def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
         """
         Loading sharded optimizer from checkpoint folder, with index file given.
         For each process, only loading optimizer states of parameters it controls.
         """
-        # TODO(Baizhou): To be implemented.
-        pass
+
+        if not os.path.isfile(checkpoint_index_file):
+            logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = optimizer.unwrap()
+
+        assert isinstance(optimizer, ZeroOptimizer)
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+
+        # Load param_groups.
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
+                               Lacking param group file under current directory.')
+        saved_param_groups = torch.load(param_group_path)
+        optimizer.load_param_groups(saved_param_groups)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        # Load optimizer states from shard files under checkpoint path.
+        # For each file, only load the states managed by current process.
+        for shard_file in checkpoint_files:
+            state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            optimizer.load_param_states(state_dict_shard)
+            del state_dict_shard
+            gc.collect()
+
+        optimizer.optimizer_loading_epilogue()

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
         Save model to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)


 class GeminiModel(ModelWrapper):
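
Both sharded save paths in this diff (model.state_dict_shard and optimizer.state_shard) rely on the same contract: an iterator that yields (shard_dict, shard_size) pairs, each kept under max_shard_size. A generic sketch of that chunking contract for plain tensor-valued entries follows; the real Gemini implementations additionally gather distributed tensors before measuring them.

import torch

def shard_states_sketch(states, max_shard_size_mb=1024):
    """Yield (shard_dict, shard_bytes) pairs, each at most ~max_shard_size_mb.

    Illustrates the iterator contract assumed by save_state_dict_shards,
    for the simple case where every value is a tensor.
    """
    limit = max_shard_size_mb * 1024 * 1024
    shard, shard_bytes = {}, 0
    for key, tensor in states.items():
        size = tensor.numel() * tensor.element_size()
        if shard and shard_bytes + size > limit:
            yield shard, shard_bytes
            shard, shard_bytes = {}, 0
        shard[key] = tensor
        shard_bytes += size
    if shard:
        yield shard, shard_bytes

# Example: two 4 MiB tensors with a 4 MiB cap end up in separate shards.
dummy = {f"w{i}": torch.zeros(1024, 1024) for i in range(2)}
print([list(s.keys()) for s, _ in shard_states_sketch(dummy, max_shard_size_mb=4)])
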