[ckpt] Add async ckpt api (#6136)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
Author: Wang Binluo
Date: 2024-11-15 18:19:16 +08:00
Committed by: Hongxin Liu
Parent: d4a436051d
Commit: 8e08c27e19
12 changed files with 174 additions and 86 deletions
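In short, the commit threads a new use_async flag through the save_unsharded_model and save_sharded_model methods of the plugin checkpoint IO classes; when the flag is set, the Gemini implementation delegates to the parent GeneralCheckpointIO, which owns the asynchronous write path. Below is a minimal, hypothetical caller-side sketch: only the use_async parameter on save_unsharded_model comes from this diff, while the launch/boost boilerplate around it is illustrative and assumes a torchrun launch with a CUDA device.

# Hypothetical usage sketch, not part of this diff. Only the use_async flag on
# save_unsharded_model is introduced by this commit; the setup around it is
# illustrative and assumes `torchrun` with at least one GPU.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)
model, *_ = booster.boost(torch.nn.Linear(8, 8))

ckpt_io = plugin.get_checkpoint_io()
# With use_async=True the Gemini checkpoint IO defers to the parent
# GeneralCheckpointIO save path instead of the blocking save_state_dict call.
ckpt_io.save_unsharded_model(
    model, "model.safetensors", gather_dtensor=True, use_safetensors=True, use_async=True
)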


@@ -65,7 +65,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()
 
-    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self,
+        model: GeminiDDP,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_safetensors: bool,
+        use_async: bool = False,
+    ):
         """
         Save sharded model to checkpoint but only on master process.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
@@ -74,7 +81,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors)
+            if use_async:
+                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors)
 
     def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
         """
@@ -112,6 +122,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save sharded model.
@@ -130,27 +141,33 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # Save shards of optimizer states.
         is_master = self.coordinator.is_master()
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint_path,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=is_master,
-            use_safetensors=use_safetensors,
-        )
-
-        # only save the index file on the master rank
-        if self.coordinator.is_master():
-            index_file.append_meta_data("total_size", total_size)
-            index_file.write_index_file(save_index_file)
-            save_config_file(model.unwrap(), checkpoint_path)
-            self.logger.info(
-                f"The model is split into checkpoint shards. "
-                f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}.",
-                ranks=[0],
-            )
+        if use_async:
+            super().save_sharded_model(
+                model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
+            )
+        else:
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=is_master,
+                use_safetensors=use_safetensors,
+            )
+
+            # only save the index file on the master rank
+            if self.coordinator.is_master():
+                index_file.append_meta_data("total_size", total_size)
+                index_file.write_index_file(save_index_file)
+                save_config_file(model.unwrap(), checkpoint_path)
+                self.logger.info(
+                    f"The model is split into checkpoint shards. "
+                    f"You can find where each parameters has been saved in the "
+                    f"index located at {save_index_file}.",
+                    ranks=[0],
+                )
 
     def load_sharded_model(
         self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
     ):
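The asynchronous write itself lives in GeneralCheckpointIO and is outside this diff. For orientation, the sketch below shows the general idea a use_async flag enables, namely snapshotting tensors to CPU and handing serialization to a background worker; the async_save helper and the thread-pool approach are assumptions for illustration, not ColossalAI's actual mechanism.

# Illustrative only: the general idea behind an asynchronous checkpoint save.
# This is NOT the GeneralCheckpointIO implementation; async_save and the
# single-worker ThreadPoolExecutor are assumptions made for this sketch.
from concurrent.futures import Future, ThreadPoolExecutor

import torch

_writer = ThreadPoolExecutor(max_workers=1)

def async_save(state_dict: dict, path: str) -> Future:
    # Snapshot to CPU first so training can keep updating the live tensors
    # while the background thread serializes the copy to disk.
    cpu_state = {
        k: v.detach().cpu().clone() if torch.is_tensor(v) else v
        for k, v in state_dict.items()
    }
    return _writer.submit(torch.save, cpu_state, path)

# The caller keeps training and only blocks on the future before the next
# save (or before exit) to make sure the previous write finished.
future = async_save(torch.nn.Linear(4, 4).state_dict(), "/tmp/ckpt.pt")
future.result()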


@@ -54,7 +54,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
         optimizer.load_state_dict(sharded_osd)
 
-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to checkpoint but only on master process.
         """
@@ -82,6 +84,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.
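In the FSDP hunks shown here the checkpoint IO only gains the use_async parameter in its signatures; no async branch is visible in this excerpt. A short hypothetical caller-side sketch, mirroring the Gemini example above and assuming a torchrun launch with a CUDA device:

# Hypothetical call, not part of this diff; the setup is illustrative and only
# the use_async parameter on save_unsharded_model comes from this commit.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

colossalai.launch_from_torch()
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)
model, *_ = booster.boost(torch.nn.Linear(8, 8).cuda())

plugin.get_checkpoint_io().save_unsharded_model(
    model, "fsdp_model.pt", gather_dtensor=True, use_safetensors=False, use_async=True
)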