[booster] torch fsdp fix ckpt (#3788)

This commit is contained in:
wukong1992
2023-05-23 16:58:45 +08:00
committed by GitHub
parent 9265f2d4d7
commit 6b305a99d6
5 changed files with 230 additions and 186 deletions

View File

@@ -196,7 +196,7 @@ class Booster:
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""Load optimizer from checkpoint.