[ckpt] Add async ckpt api (#6136)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
This commit is contained in:
Wang Binluo
2024-11-15 18:19:16 +08:00
committed by Hongxin Liu
parent d4a436051d
commit 8e08c27e19
12 changed files with 174 additions and 86 deletions

View File

@@ -176,10 +176,10 @@ class CheckpointIO(ABC):
if shard:
self.save_sharded_model(
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async
)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
"""