Mirror of https://github.com/hpcaitech/ColossalAI.git
[ckpt] Add async ckpt api (#6136)
* fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix
@@ -65,7 +65,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()
 
-    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self,
+        model: GeminiDDP,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_safetensors: bool,
+        use_async: bool = False,
+    ):
         """
         Save sharded model to checkpoint but only on master process.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
@@ -74,7 +81,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors)
+            if use_async:
+                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors)
 
     def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
         """
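In the async branch above, the Gemini IO simply defers to GeneralCheckpointIO.save_unsharded_model with use_async=True; what that base-class path does is not part of this diff. As orientation only, the usual async-checkpoint pattern looks like the minimal sketch below: take a CPU snapshot of the state dict synchronously, then let a background thread perform the file write so training is not blocked on disk I/O. The helper name and the plain torch.save call are illustrative assumptions, not ColossalAI API.

import threading

import torch
import torch.nn as nn


def async_save_state_dict(model: nn.Module, checkpoint: str) -> threading.Thread:
    # Snapshot to CPU first so the bytes on disk cannot race with further parameter updates.
    cpu_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    def _write() -> None:
        torch.save(cpu_state, checkpoint)

    writer = threading.Thread(target=_write, daemon=True)
    writer.start()
    return writer  # the caller joins this handle once the file must be complete


if __name__ == "__main__":
    m = nn.Linear(4, 4)
    handle = async_save_state_dict(m, "model.pt")
    handle.join()  # e.g. at the end of the step or before process exit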
@@ -112,6 +122,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save sharded model.
@@ -130,27 +141,33 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
 
         # Save shards of optimizer states.
         is_master = self.coordinator.is_master()
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint_path,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=is_master,
-            use_safetensors=use_safetensors,
-        )
-
-        # only save the index file on the master rank
-        if self.coordinator.is_master():
-            index_file.append_meta_data("total_size", total_size)
-            index_file.write_index_file(save_index_file)
-            save_config_file(model.unwrap(), checkpoint_path)
-            self.logger.info(
-                f"The model is split into checkpoint shards. "
-                f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}.",
-                ranks=[0],
-            )
+        if use_async:
+            super().save_sharded_model(
+                model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
+            )
+
+        else:
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=is_master,
+                use_safetensors=use_safetensors,
+            )
+
+            # only save the index file on the master rank
+            if self.coordinator.is_master():
+                index_file.append_meta_data("total_size", total_size)
+                index_file.write_index_file(save_index_file)
+                save_config_file(model.unwrap(), checkpoint_path)
+                self.logger.info(
+                    f"The model is split into checkpoint shards. "
+                    f"You can find where each parameters has been saved in the "
+                    f"index located at {save_index_file}.",
+                    ranks=[0],
+                )
 
     def load_sharded_model(
         self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
     ):
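The sharded path is reorganized the same way: the synchronous else branch keeps the previous behaviour (write the shards with save_state_dict_shards, then let the master rank write the index and config files), while use_async again delegates everything to the GeneralCheckpointIO base class, whose implementation is outside this hunk. The sketch below is only a generic illustration of how a sharded save can overlap shard writes on worker threads and commit the index only after every shard has landed, mirroring the shards-then-index order above; none of these names are ColossalAI functions.

from concurrent.futures import ThreadPoolExecutor
import json
import os

import torch
import torch.nn as nn


def async_save_sharded(model: nn.Module, checkpoint_dir: str, max_shard_bytes: int = 1 << 20) -> None:
    os.makedirs(checkpoint_dir, exist_ok=True)
    state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    # Greedily pack tensors into shards no larger than max_shard_bytes.
    shards, current, size = [], {}, 0
    for name, tensor in state.items():
        nbytes = tensor.numel() * tensor.element_size()
        if current and size + nbytes > max_shard_bytes:
            shards.append(current)
            current, size = {}, 0
        current[name] = tensor
        size += nbytes
    if current:
        shards.append(current)

    index = {
        "metadata": {"total_size": sum(t.numel() * t.element_size() for t in state.values())},
        "weight_map": {},
    }
    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = []
        for i, shard in enumerate(shards):
            fname = f"shard_{i:05d}.bin"
            for name in shard:
                index["weight_map"][name] = fname
            # Each shard is written by a worker thread; writes proceed in parallel.
            futures.append(pool.submit(torch.save, shard, os.path.join(checkpoint_dir, fname)))
        for f in futures:
            f.result()  # surface any I/O error before the index is written

    # The index goes last, so a complete index file implies complete shards.
    with open(os.path.join(checkpoint_dir, "index.json"), "w") as fp:
        json.dump(index, fp, indent=2)


if __name__ == "__main__":
    async_save_sharded(nn.Linear(64, 64), "./sharded_ckpt")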
@@ -54,7 +54,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
         optimizer.load_state_dict(sharded_osd)
 
-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to checkpoint but only on master process.
         """
@@ -82,6 +84,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.
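For context, a hypothetical call site: the keyword names passed to save_unsharded_model below are exactly those added in the Gemini hunk, while the surrounding setup (launch_from_torch, GeminiPlugin, HybridAdam, Booster, get_checkpoint_io) is the standard ColossalAI recipe and is assumed rather than shown by this diff. A real run needs GPUs and a torchrun launch.

import colossalai
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # assumes the process was started with torchrun

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# Exercise the signature added above; the keyword names match the diff.
ckpt_io = plugin.get_checkpoint_io()
ckpt_io.save_unsharded_model(
    model, "model.pt", gather_dtensor=True, use_safetensors=False, use_async=True
)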