Mirror of https://github.com/hpcaitech/ColossalAI.git
[ckpt] Add async ckpt api (#6136)
* fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix
@@ -117,6 +117,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ) -> None:
         """
         Save sharded model checkpoint under the given checkpointing path.
@@ -161,24 +162,27 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
 
         if self.pp_size == 1 and self.ep_size == 1:
             # When pipeline is not used, save the model shards as in general checkpointIO
-            total_size = save_state_dict_shards(
-                sharded_state_dict=state_dict_shard,
-                checkpoint=checkpoint,
-                index_file=index_file,
-                base_filename=weights_name,
-                is_master=control_saving,
-                use_safetensors=use_safetensors,
-            )
-            if control_saving:
-                index_file.append_meta_data("total_size", total_size)
-                index_file.write_index_file(save_index_file)
-                save_config_file(model, checkpoint)
-                if self.verbose and self.coordinator.is_master():
-                    logging.info(
-                        f"The model is split into checkpoint shards. "
-                        f"You can find where each parameters has been saved in the "
-                        f"index located at {save_index_file}."
-                    )
+            if use_async:
+                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
+            else:
+                total_size = save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=weights_name,
+                    is_master=control_saving,
+                    use_safetensors=use_safetensors,
+                )
+                if control_saving:
+                    index_file.append_meta_data("total_size", total_size)
+                    index_file.write_index_file(save_index_file)
+                    save_config_file(model, checkpoint)
+                    if self.verbose and self.coordinator.is_master():
+                        logging.info(
+                            f"The model is split into checkpoint shards. "
+                            f"You can find where each parameters has been saved in the "
+                            f"index located at {save_index_file}."
+                        )
 
             dist.barrier()
         else:
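
With this change, a sharded MoE save falls back to the parent HybridParallelCheckpointIO's async path whenever use_async=True (and pp_size == ep_size == 1); otherwise the original synchronous shard writing is kept. A minimal usage sketch follows; the plugin variable, the get_checkpoint_io() accessor, and the argument values are illustrative assumptions and are not part of this commit.

    # Sketch only: assumes `plugin` is an already-constructed MoE-capable plugin
    # whose get_checkpoint_io() returns a MoECheckpointIO, and `model` is the
    # boosted/wrapped model. Argument values are illustrative.
    ckpt_io = plugin.get_checkpoint_io()

    ckpt_io.save_sharded_model(
        model,
        checkpoint="ckpt/step_1000",   # directory that will hold the shards and index file
        gather_dtensor=True,
        prefix=None,
        size_per_shard=1024,
        use_safetensors=False,
        use_async=True,                # new flag introduced by this commit
    )
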
@@ -708,10 +712,20 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         checkpoint: str,
         gather_dtensor: bool,
         use_safetensors: bool,
+        use_async: bool = False,
     ):
         state_dict = self.pre_save_model(model)
         if dist.get_rank() == 0:
-            torch.save(state_dict, checkpoint)
+            if use_async:
+                super().save_unsharded_model(
+                    model=model,
+                    checkpoint=checkpoint,
+                    gather_dtensor=gather_dtensor,
+                    use_safetensors=use_safetensors,
+                    use_async=use_async,
+                )
+            else:
+                torch.save(state_dict, checkpoint)
         dist.barrier()
 
     # Copied from colossalai.moe
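
The unsharded path keeps the rank-0-only write, but when use_async=True the write is delegated to the parent class instead of a blocking torch.save. A hedged sketch of driving it directly is below; in practice this method is normally reached through higher-level save APIs, and the file name here is illustrative.

    # Sketch only: `ckpt_io` and `model` as in the previous example. With
    # use_async=True the call may return before the bytes reach disk, so
    # callers should synchronize through whatever completion mechanism the
    # parent CheckpointIO provides before shutting down.
    ckpt_io.save_unsharded_model(
        model,
        checkpoint="ckpt/model.safetensors",
        gather_dtensor=True,
        use_safetensors=True,
        use_async=True,   # new flag: route through the parent's async save
    )
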