Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
[checkpointio] support unsharded checkpointIO for hybrid parallel (#4774)
* support unsharded saving/loading for model
* support optimizer unsharded saving
* update doc
* support unsharded loading for optimizer
* small fix
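For orientation, the unsharded paths added by this commit are normally reached through the Booster front-end by passing shard=False to save_model/save_optimizer. Below is a minimal sketch under stated assumptions (two GPUs, a Shardformer-supported GPT-2 from HuggingFace, and Booster/HybridParallelPlugin argument names as recalled from this commit's timeframe); treat it as illustrative, not as the commit's own test code.

```python
# Launch with: torchrun --nproc_per_node=2 demo_unsharded_ckpt.py
# Assumptions: 2 GPUs, transformers installed, plugin/Booster argument names as recalled.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})

# tp_size=2, pp_size=1 keeps the sketch simple; the checkpoint IO used is still
# HybridParallelCheckpointIO, whose unsharded methods this commit implements.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp16")
booster = Booster(plugin=plugin)

model = GPT2LMHeadModel(GPT2Config())
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop ...

# shard=False routes into the new unsharded save/load methods shown in the diff below.
booster.save_model(model, "model.pt", shard=False)
booster.save_optimizer(optimizer, "optimizer.pt", shard=False)

booster.load_model(model, "model.pt")
booster.load_optimizer(optimizer, "optimizer.pt")
```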
@@ -9,7 +9,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
-from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

 from colossalai.cluster import DistCoordinator
@@ -24,10 +23,12 @@ from .utils import (
     get_optimizer_base_filenames,
     is_safetensors_available,
     load_shard_state_dict,
+    load_state_dict,
     load_state_dict_into_model,
     load_states_into_optimizer,
     save_config_file,
     save_param_groups,
+    save_state_dict,
     save_state_dict_shards,
     search_tp_partition_dim,
     sharded_optimizer_loading_epilogue,
@@ -119,13 +120,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         use_zero: bool,
         dp_group: ProcessGroup,
         tp_group: ProcessGroup,
-        master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
         size_per_shard: int = 1024,
     ):
         # An internel method that breaks state_dict of optimizer into shards within limited size.

         state_dict_sharder = StateDictSharder(size_per_shard)
         param_info = optimizer.param_info
+        master_to_working_map = optimizer.get_master_to_working_map()

         for param, state in optimizer.optim.state.items():
             if param is None:
@@ -217,7 +218,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
                 save_config_file(model, checkpoint)
-                if self.verbose:
+                if self.verbose and self.coordinator.is_master():
                     logging.info(
                         f"The model is split into checkpoint shards. "
                         f"You can find where each parameters has been saved in the "
@@ -273,7 +274,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 final_index_file.write_index_file(final_index_file_path)
                 save_config_file(model, checkpoint)
                 rmtree(tmp_index_file_folder)
-                if self.verbose:
+                if self.verbose and self.coordinator.is_master():
                     logging.info(
                         f"The model is split into checkpoint shards. "
                         f"You can find where each parameters has been saved in the "
@@ -353,7 +354,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()

-        if self.verbose:
+        if self.verbose and self.coordinator.is_master():
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

     def save_sharded_optimizer(
@@ -399,7 +400,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_zero=self.use_zero,
             dp_group=self.dp_group,
             tp_group=self.tp_group,
-            master_to_working_map=optimizer.get_master_to_working_map(),
             size_per_shard=size_per_shard,
         )
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
@@ -424,7 +424,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 # Store index file.
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
-                if self.verbose:
+                if self.verbose and self.coordinator.is_master():
                     logging.info(
                         f"The optimizer is going to be split to checkpoint shards. "
                         f"You can find where each parameters has been saved in the "
@@ -484,7 +484,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 final_index_file.write_index_file(final_index_file_path)
                 rmtree(tmp_index_file_folder)

-                if self.verbose:
+                if self.verbose and self.coordinator.is_master():
                     logging.info(
                         f"The model is split into checkpoint shards. "
                         f"You can find where each parameters has been saved in the "
@@ -579,24 +579,196 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             optimizer.optim.state[param] = sharded_state

         sharded_optimizer_loading_epilogue(optimizer.optim)
-        if self.verbose:
+        if self.verbose and self.coordinator.is_master():
             logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
-        # TODO(Baizhou): support this feature after implementing complete state_dict collection
-        raise NotImplementedError
-
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
-        # TODO(Baizhou): support this feature after implementing complete state_dict collection
-        raise NotImplementedError
-
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
-        # TODO(Baizhou): support this feature after implementing complete state_dict collection
-        raise NotImplementedError
-
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
-        # TODO(Baizhou): support this feature after implementing complete state_dict collection
-        raise NotImplementedError
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        """
+        Save model state dict to a single file with given checkpointing path.
+
+        Args:
+            model (nn.Module): Model on local device to be saved.
+            checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
+            gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
+            use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+        """
+        if self.coordinator.is_master():
+            logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model = model.unwrap()
+
+        if self.dp_rank != 0:
+            return
+
+        # The logic of collecting parameter shards along tp degree
+        # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
+        state_dict = model.state_dict()
+
+        if self.pp_size == 1:
+            # When pipeline is not used, let master rank directly save the collected state_dict.
+            if self.tp_rank == 0:
+                save_state_dict(state_dict, checkpoint, use_safetensors)
+        else:
+            # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
+            state_dict_list = [None for _ in range(self.pp_size)]
+            dist.barrier(self.pp_group)
+            dist.all_gather_object(state_dict_list, state_dict, self.pp_group)
+
+            # Only the master rank do the saving.
+            if self.coordinator.is_master():
+                complete_state_dict = dict()
+                for _state_dict in state_dict_list:
+                    complete_state_dict.update(_state_dict)
+                save_state_dict(complete_state_dict, checkpoint, use_safetensors)
+
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
+        """
+        Load model from a single file with the given path of checkpoint.
+
+        Args:
+            model (nn.Module): The model to be loaded.
+            checkpoint_index_file (str): Path to the checkpoint file.
+            strict (bool, optional): For name matching during loading state_dict. Defaults to False.
+                                     This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
+        """
+        if self.coordinator.is_master():
+            logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        strict = False
+        model_before_wrapping = model
+        model = model.unwrap()
+
+        # Load from checkpoint. Since the logic of breaking parameter shards along tp degree
+        # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
+        # model.load_state_dict can be directly called.
+        state_dict = load_state_dict(checkpoint)
+        model.load_state_dict(state_dict, strict=strict)
+
+        # Update master params if mixed-precision training is enabled.
+        model_before_wrapping.update_master_params()
+
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+        """
+        Save optimizer state dict to a file with given path.
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
+            checkpoint (str): Path to save optimizer state_dict.
+            gather_dtensor (bool): Whether to gather_dtensor, not used.
+        """
+        if self.coordinator.is_master():
+            logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+
+        # optimizer states of parameters kept by local device('s pipeline stage)
+        local_states = dict()
+
+        for param, state in optimizer.optim.state.items():
+            if param is None:
+                continue
+
+            # working param is needed for obtaining correct param_id
+            master_to_working_map = optimizer.get_master_to_working_map()
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+
+            # gather complete state from tp shards & dp shards
+            param_id = optimizer.param_info["param2id"][id(working_param)]
+            original_shape = optimizer.param_info["param2shape"][id(working_param)]
+            local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+                state,
+                working_param,
+                original_shape=original_shape,
+                dp_group=self.dp_group,
+                tp_group=self.tp_group,
+                use_zero=self.use_zero,
+                inplace=False,
+                device=torch.device("cuda"),
+            )
+
+        if self.pp_size == 1:
+            # When pipeline is not used, let master rank directly save the collected state_dict.
+            state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+            if self.coordinator.is_master():
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+        else:
+            # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
+            states_list = [None for _ in range(self.pp_size)]
+            dist.barrier(self.pp_group)
+            dist.all_gather_object(states_list, local_states, self.pp_group)
+
+            # Only the master rank do the saving.
+            if self.coordinator.is_master():
+                state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+                for _states in states_list:
+                    state_dict["state"].update(_states)
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        """
+        Load optimizer from a file with given path.
+
+        Args:
+            optimizer (OptimizerWrapper): The optimizer to be loaded.
+            checkpoint_index_file (str): Path to the checkpoint file.
+        """
+
+        def _get_param_id_from_optimizer_param(
+            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+        ):
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+            return optimizer.param_info["param2id"][id(working_param)]
+
+        if self.coordinator.is_master():
+            logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+        # Complete optimizer state_dict loaded from checkpoint, need to be processed later.
+        state_dict = load_state_dict(checkpoint)
+
+        # Load param_groups.
+        updated_groups = []
+        saved_groups = state_dict["param_groups"]
+        for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+            new_pg = copy.deepcopy(saved_pg)
+            new_pg["params"] = old_pg["params"]  # Only keep the parameters kept by current pipeline stage.
+            updated_groups.append(new_pg)
+        optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+        # Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
+        master_to_working_map = optimizer.get_master_to_working_map()
+        id_map = {}
+        for pg in optimizer.optim.param_groups:
+            for param in pg["params"]:
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                id_map[param_id] = param
+        load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
+
+        # Then shard the loaded optimizer states if using tp/zero.
+        for param, state in optimizer.optim.state.items():
+            if param is None:
+                continue
+            device = param.device
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+            original_shape = optimizer.param_info["param2shape"][id(working_param)]
+            sharded_state = self.shard_from_complete_optimizer_state(
+                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
+            )
+            optimizer.optim.state[param] = sharded_state
+
+        sharded_optimizer_loading_epilogue(optimizer.optim)

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
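The pp_size > 1 branches above gather one picklable object per pipeline stage with dist.all_gather_object and merge everything on the master rank before writing a single file. Below is a self-contained sketch of that pattern on CPU; the gloo backend, port, process count, and file name are assumptions for the demo, not ColossalAI code.

```python
# Each process plays the role of one "pipeline stage": it owns a disjoint slice of
# the full state, gathers every stage's slice, and rank 0 merges and saves it.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Disjoint per-stage state, standing in for the partial model/optimizer state_dict.
    local_state = {f"layer{rank}.weight": torch.full((2, 2), float(rank))}

    gathered = [None for _ in range(world_size)]
    dist.barrier()
    dist.all_gather_object(gathered, local_state)

    if rank == 0:
        complete = {}
        for piece in gathered:
            complete.update(piece)  # merge stage dicts into one unsharded dict
        torch.save(complete, "unsharded_demo.pt")
        print(sorted(complete.keys()))

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```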
@@ -614,6 +786,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         tp_group: ProcessGroup,
         use_zero: bool,
         inplace: bool,
+        device: torch.device = torch.device("cpu"),
     ) -> OrderedDict:
         """
         With given parameter and its optimizer states, gather the complete optimizer state for saving.
@@ -626,6 +799,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             tp_group (ProcessGroup): The process group of tensor parallel.
             use_zero (bool): Whether Zero is used.
             inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+            device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').

         Returns:
             OrderedDict: The complete optimizer state of given parameter.
@@ -651,7 +825,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 dist.all_gather(gather_tensor, v, group=tp_group)
                 v = torch.cat(gather_tensor, dim=partition_dim)

-            state_[k] = v.detach().clone().cpu()
+            state_[k] = v.detach().clone().to(device)

         return state_
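The last hunk replaces the hard-coded .cpu() with .to(device), so the caller chooses where the gathered optimizer state is materialized (save_unsharded_optimizer above requests CUDA). A minimal sketch of the gather-and-place step, assuming the tensor was evenly split along partition_dim across the tensor-parallel group; the helper name and signature are illustrative, not the class's actual method.

```python
import torch
import torch.distributed as dist


def gather_tp_state(v: torch.Tensor,
                    partition_dim,
                    tp_group: dist.ProcessGroup,
                    device: torch.device = torch.device("cpu")) -> torch.Tensor:
    tp_size = dist.get_world_size(tp_group)
    if tp_size > 1 and partition_dim is not None:
        # Collect every rank's shard and stitch them back along the split dimension.
        shards = [torch.empty_like(v) for _ in range(tp_size)]
        dist.all_gather(shards, v, group=tp_group)
        v = torch.cat(shards, dim=partition_dim)
    # Detach and copy so the saved state no longer aliases live optimizer memory,
    # then place it on whatever device the caller asked for.
    return v.detach().clone().to(device)
```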